This repository was archived by the owner on Jan 30, 2025. It is now read-only.

Commit e8b4efd

support lowering of aten.log1p

1 parent 57d5e00 commit e8b4efd

2 files changed: +51 -0 lines changed

lib/Conversion/TorchToTcp/Elementwise.cpp (+32)
@@ -478,6 +478,37 @@ class ConvertAtenSqrtOp : public OpConversionPattern<AtenSqrtOp> {
   }
 };
 
+class ConvertAtenLog1pOp : public OpConversionPattern<AtenLog1pOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenLog1pOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value input = adaptor.getSelf();
+    RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
+
+    if (!inputType)
+      return rewriter.notifyMatchFailure(
+          op, "Only Ranked Tensor types are supported in TCP");
+
+    auto elementType = inputType.getElementType();
+    if (!isa<mlir::FloatType>(elementType))
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point datatype is supported");
+
+    auto constOp = torch_to_tcp::getConstTensor<float>(
+                       rewriter, op, llvm::ArrayRef((float)1.0), {})
+                       .value();
+    constOp = torch_to_tcp::broadcast0DOr1DToNDAndMatchShape(
+        rewriter, constOp, input, elementType);
+    auto addOp =
+        rewriter.create<tcp::AddOp>(op.getLoc(), inputType, input, constOp);
+    rewriter.replaceOpWithNewOp<tcp::LogOp>(op, inputType, addOp);
+    return success();
+  }
+};
+
 template <typename AtenOpT, typename TcpOpT>
 class ConvertAtenUnaryIntOrFpOp : public OpConversionPattern<AtenOpT> {
 public:

@@ -694,6 +725,7 @@ void torch_to_tcp::populateElementwisePatternsAndLegality(
   INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenBatchNormOp);
   INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenAtan2Op);
   INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenSqrtOp);
+  INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenLog1pOp);
 #undef INSERT_ATEN_ELEMENTWISE_OP_PATTERN
 
 #define INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenOp, TcpOp) \
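For reference, the new pattern implements the textbook decomposition log1p(x) = log(1 + x): it materializes a 0-D constant 1.0 tensor, broadcasts it to match the input shape, adds it to the input, and takes the elementwise log. A minimal scalar sketch of the resulting semantics (plain C++; the function name is illustrative, not from the repo):

    #include <cmath>

    // Elementwise behavior of the lowered graph: broadcast-add 1.0f, then log,
    // mirroring the tcp.add + tcp.log sequence the pattern emits.
    float loweredLog1p(float x) { return std::log(x + 1.0f); }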

test/Conversion/TorchToTcp/elementwise.mlir (+19)
@@ -762,3 +762,22 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],ui8> {
   %0 = torch.aten.to.dtype %arg0, %int11, %false, %false, %none : !torch.vtensor<[?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],ui8>
   return %0 : !torch.vtensor<[?,?],ui8>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.log1p(
+// CHECK-SAME:     %[[ARG0:.*]]: !torch.vtensor<[?,4,19,2],f32>) -> !torch.vtensor<[?,4,19,2],f32> {
+// CHECK-DAG:   %[[TO_BUILTIN0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,4,19,2],f32> -> tensor<?x4x19x2xf32>
+// CHECK:       %[[CONST:.*]] = tcp.const {value = dense<1.000000e+00> : tensor<f32>} : tensor<f32>
+// CHECK:       %[[EXPAND_SHAPE:.*]] = tensor.expand_shape %[[CONST]] [] output_shape [1, 1, 1, 1] : tensor<f32> into tensor<1x1x1x1xf32>
+// CHECK:       %[[CONST0:.*]] = arith.constant 0 : index
+// CHECK:       %[[DIM0:.*]] = tensor.dim %[[TO_BUILTIN0]], %[[CONST0]] : tensor<?x4x19x2xf32>
+// CHECK:       %[[BROADCAST:.*]] = tcp.broadcast %[[EXPAND_SHAPE]], %[[DIM0]]
+// CHECK:       %[[ADD:.*]] = tcp.add %[[TO_BUILTIN0]], %[[BROADCAST]] : tensor<?x4x19x2xf32>, tensor<?x4x19x2xf32> -> tensor<?x4x19x2xf32>
+// CHECK:       %[[LOG:.*]] = tcp.log %[[ADD]] : tensor<?x4x19x2xf32> -> tensor<?x4x19x2xf32>
+// CHECK:       %[[FROM_BUILTIN:.*]] = torch_c.from_builtin_tensor %[[LOG]] : tensor<?x4x19x2xf32> -> !torch.vtensor<[?,4,19,2],f32>
+// CHECK:       return %[[FROM_BUILTIN]] : !torch.vtensor<[?,4,19,2],f32>
+func.func @torch.aten.log1p(%arg0: !torch.vtensor<[?,4,19,2],f32>) -> !torch.vtensor<[?,4,19,2],f32> {
+  %1 = torch.aten.log1p %arg0 : !torch.vtensor<[?,4,19,2],f32> -> !torch.vtensor<[?,4,19,2],f32>
+  return %1 : !torch.vtensor<[?,4,19,2],f32>
+}
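A caveat worth noting about this decomposition (an observation, not something the commit claims to address): log1p exists precisely because log(1 + x) loses accuracy when |x| is much smaller than 1, and here the add happens in f32 before the log. For example, 1.0f + 1e-10f rounds to exactly 1.0f, so the lowered form returns 0 where a fused log1p returns about 1e-10. A self-contained C++ demonstration:

    #include <cmath>
    #include <cstdio>

    int main() {
      float x = 1e-10f;
      // Decomposed form, as this lowering computes it: add 1, then log.
      std::printf("log(1+x) = %g\n", std::log(1.0f + x)); // 0: 1.0f + x rounds to 1.0f
      // Dedicated library routine, accurate near zero.
      std::printf("log1p(x) = %g\n", std::log1p(x));      // ~1e-10
      return 0;
    }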
