Skip to content

Commit acafb2e

Browse files
LukeBoyertensorflower-gardener
authored andcommitted
Refactor/cleanup hlo -> tfl pass
* Make the convert dot general function an explicit pattern class and register in cc * Remove unused functions * Remove redundant commentary PiperOrigin-RevId: 649720051
1 parent 5d20819 commit acafb2e

File tree

4 files changed

+30
-60
lines changed

4 files changed

+30
-60
lines changed

tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
// This file implements logic for legalizing mhlo.dot_general to
1717
// tflite.batch_matmul.
18+
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h"
1819

1920
#include <cstddef>
2021
#include <cstdint>
@@ -39,6 +40,7 @@ limitations under the License.
3940
#include "mlir/IR/Value.h" // from @llvm-project
4041
#include "mlir/IR/ValueRange.h" // from @llvm-project
4142
#include "mlir/Support/LLVM.h" // from @llvm-project
43+
#include "mlir/Support/LogicalResult.h" // from @llvm-project
4244
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
4345
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
4446
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
@@ -408,30 +410,15 @@ Value ConvertDot(PatternRewriter& rewriter, Value lhs, Value rhs,
408410
return reshaped.getResult();
409411
}
410412

411-
// Converts mhlo.dot_general to tfl.BatchMatMul. Reshape and Transpose ops will
412-
// be inserted when necessary. See ConvertDotGeneralOp for additional notes.
413-
Value ConvertDotOp(PatternRewriter& rewriter, Operation* old_op) {
414-
auto dot_op = cast<mhlo::DotOp>(old_op);
415-
auto lhs_rank = mlir::cast<ShapedType>(dot_op.getLhs().getType()).getRank();
416-
auto dot_dimension_numbers =
417-
mhlo::DotDimensionNumbersAttr::get(rewriter.getContext(),
418-
/*lhsBatchingDimensions=*/{},
419-
/*rhsBatchingDimensions=*/{},
420-
/*lhsContractingDimensions=*/
421-
{lhs_rank == 1 ? 0 : 1},
422-
/*rhsContractingDimensions=*/{0});
423-
return ConvertDot(
424-
rewriter, dot_op.getLhs(), dot_op.getRhs(), dot_dimension_numbers,
425-
mlir::cast<ShapedType>(dot_op.getResult().getType()), dot_op.getLoc());
413+
LogicalResult LowerDotGeneralOp::matchAndRewrite(
414+
mhlo::DotGeneralOp op, OpAdaptor adaptor,
415+
ConversionPatternRewriter& rewriter) const {
416+
auto val = ConvertDot(
417+
rewriter, op.getLhs(), op.getRhs(), op.getDotDimensionNumbers(),
418+
mlir::cast<ShapedType>(op.getResult().getType()), op.getLoc());
419+
rewriter.replaceOp(op, val.getDefiningOp());
420+
return mlir::success();
426421
}
427422

428-
Value ConvertDotGeneralOp(PatternRewriter& rewriter, Operation* old_op) {
429-
auto dot_general_op = cast<mhlo::DotGeneralOp>(old_op);
430-
return ConvertDot(
431-
rewriter, dot_general_op.getLhs(), dot_general_op.getRhs(),
432-
dot_general_op.getDotDimensionNumbers(),
433-
mlir::cast<ShapedType>(dot_general_op.getResult().getType()),
434-
dot_general_op.getLoc());
435-
}
436423
} // namespace odml
437424
} // namespace mlir

tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,9 @@ limitations under the License.
1818
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_DOT_GENERAL_H_
1919
#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_DOT_GENERAL_H_
2020

21-
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22-
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
23-
#include "mlir/IR/Operation.h" // from @llvm-project
24-
#include "mlir/IR/PatternMatch.h" // from @llvm-project
2521
#include "mlir/IR/Value.h" // from @llvm-project
2622
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
23+
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
2724

2825
namespace mlir {
2926
namespace odml {
@@ -41,7 +38,14 @@ namespace odml {
4138
// Concat ) are inserted for shape inference purposes.
4239
// 4) All the DotOp are converted to DotGeneral during the optimization pass
4340
// (ConvertDotOp).
44-
Value ConvertDotGeneralOp(PatternRewriter& rewriter, Operation* old_op);
41+
class LowerDotGeneralOp : public OpConversionPattern<mhlo::DotGeneralOp> {
42+
public:
43+
using OpConversionPattern::OpConversionPattern;
44+
45+
LogicalResult matchAndRewrite(
46+
mhlo::DotGeneralOp op, OpAdaptor adaptor,
47+
ConversionPatternRewriter& rewriter) const final;
48+
};
4549
} // namespace odml
4650
} // namespace mlir
4751

tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -109,20 +109,22 @@ std::optional<bool> IsCbrtLegal(mhlo::CbrtOp op) {
109109

110110
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/generated_tflite_legalize_hlo.inc"
111111
void LegalizeHloToTfLitePass::runOnOperation() {
112-
MLIRContext& context = getContext();
113-
RewritePatternSet patterns(&getContext());
114-
// Add new conversion patterns here.
115-
PopulateLegalizeHloToTFLitePatterns(&patterns, &context);
112+
MLIRContext* context = &getContext();
116113

117-
ConversionTarget target(context);
114+
RewritePatternSet patterns(context);
115+
patterns.add<odml::ConvertCustomCallOp, odml::LowerDotGeneralOp,
116+
ConvertReduceOpToTFLiteArgmin, ConvertReduceOpToTFLiteArgmax>(
117+
context);
118+
populateWithGenerated(patterns);
119+
120+
ConversionTarget target(*context);
118121
target.addLegalDialect<TFL::TensorFlowLiteDialect, mhlo::MhloDialect>();
119122
target.addLegalOp<func::CallOp, func::ConstantOp, arith::ConstantOp>();
120123
target.addDynamicallyLegalOp<mhlo::CustomCallOp>(IsCustomCallLegal);
121124
target.addDynamicallyLegalOp<mhlo::ReduceOp>(IsReduceOpLegal);
122-
// Converted MHLO ops should be marked illegal here.
123-
// TODO: b/304003568 - Add TF_TransposeOp folding logic to tflite.
124125
target.addDynamicallyLegalOp<mhlo::CbrtOp>(IsCbrtLegal);
125126
target.addIllegalOp<mhlo::DotGeneralOp, mhlo::DotOp, mhlo::TransposeOp>();
127+
126128
if (failed(applyPartialConversion(getOperation(), target,
127129
std::move(patterns)))) {
128130
getOperation().emitError("mhlo to TFLite legalization failed.");
@@ -131,14 +133,6 @@ void LegalizeHloToTfLitePass::runOnOperation() {
131133
}
132134
} // namespace
133135

134-
void PopulateLegalizeHloToTFLitePatterns(RewritePatternSet* patterns,
135-
MLIRContext* context) {
136-
patterns->add<odml::ConvertCustomCallOp>(context);
137-
populateWithGenerated(*patterns);
138-
139-
patterns->add<ConvertReduceOpToTFLiteArgmin, ConvertReduceOpToTFLiteArgmax>(
140-
context);
141-
}
142136

143137
// Creates an instance of the pass.
144138
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeHloToTfLitePass() {

tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
// This is the legalization pattern definition file for HLO to TFL.
17-
1816
include "mlir/IR/OpBase.td"
1917
include "mlir/Dialect/Func/IR/FuncOps.td"
2018
include "mhlo/IR/hlo_ops.td"
@@ -23,30 +21,17 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
2321
include "tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td"
2422
include "mlir/Dialect/Arith/IR/ArithOps.td"
2523

26-
27-
2824
def CreateTFLCastToInt32Op : NativeCodeCall<
2925
"CreateCastToInt32($0, $_loc, $_builder)">;
30-
// TODO: b/304003568 - Add TF_TransposeOp folding logic to tflite.
26+
3127
def LegalizeTranspose : Pat<(MHLO_TransposeOp $arg, $perm),
3228
(TFL_TransposeOp $arg,
3329
(CreateTFLCastToInt32Op (TFL_ConstOp $perm)))>;
3430

35-
36-
def ConvertDotGeneralOp : NativeCodeCall<"ConvertDotGeneralOp($_builder, "
37-
"$0.getDefiningOp())">;
38-
39-
def LegalizeDotGeneral: Pat<(MHLO_DotGeneralOp:$old_value
40-
$lhs,
41-
$rhs,
42-
$dot_dimension_numbers, $precision_config),
43-
(ConvertDotGeneralOp $old_value)>;
44-
45-
def LowerCbrt : Pat<(MHLO_CbrtOp $opr),
31+
def LowerCbrt : Pat<(MHLO_CbrtOp $opr),
4632
(TFL_PowOp $opr,
4733
(TFL_DivOp
4834
(Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "1.0f">),
4935
(Arith_ConstantOp ConstantAttr<RankedF32ElementsAttr<[]>, "3.0f">),
5036
TFL_AF_None)),
5137
[(F32Tensor $opr)]>;
52-

0 commit comments

Comments
 (0)