Skip to content

Commit c3cd279

Browse files
committed
Fix reshape concat with convert
1 parent d6e5d9c commit c3cd279

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -849,9 +849,9 @@ struct ConcatAppendingReshape final
849849
size_t frontSize = 0;
850850
for (auto v : op.getOperands()) {
851851
if (auto t = v.getDefiningOp<stablehlo::ConvertOp>()) {
852-
converts.push_back(
853-
t.getType().cast<RankedTensorType>().getElementType());
854852
v = t.getOperand();
853+
converts.push_back(
854+
v.getType().cast<RankedTensorType>().getElementType());
855855
} else
856856
converts.push_back(nullptr);
857857
if (auto t = v.getDefiningOp<stablehlo::ReshapeOp>()) {
@@ -906,11 +906,8 @@ struct ConcatAppendingReshape final
906906
lhs2);
907907

908908
if (typeconvert)
909-
res2 = rewriter.create<stablehlo::ConvertOp>(
910-
op.getLoc(),
911-
RankedTensorType::get(
912-
res2.getType().cast<RankedTensorType>().getShape(), typeconvert),
913-
res2);
909+
res2 = rewriter.create<stablehlo::ConvertOp>(op.getLoc(), op.getType(),
910+
res2);
914911

915912
rewriter.replaceOp(op, res2);
916913
return success();

test/lit_tests/concatreshape.mlir

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,17 @@ module {
99
return %concat : tensor<2x3x4xf32>
1010
}
1111

12+
func.func @main2(%a : tensor<3x4xf32>, %b : tensor<3x4xf32>) -> tensor<2x3x4xf64> {
13+
%u = stablehlo.reshape %a : (tensor<3x4xf32>) -> tensor<1x3x4xf32>
14+
%uc = stablehlo.convert %u : (tensor<1x3x4xf32>) -> tensor<1x3x4xf64>
15+
%v = stablehlo.reshape %b : (tensor<3x4xf32>) -> tensor<1x3x4xf32>
16+
%vc = stablehlo.convert %v : (tensor<1x3x4xf32>) -> tensor<1x3x4xf64>
17+
%concat = stablehlo.concatenate %uc, %vc, dim=0 : (tensor<1x3x4xf64>, tensor<1x3x4xf64>) -> tensor<2x3x4xf64>
18+
return %concat : tensor<2x3x4xf64>
19+
}
20+
1221
// TODO: support this optimization
13-
func.func @main2(%a : tensor<3x4xf32>, %b : tensor<3x4xf32>) -> tensor<3x2x4xf32> {
22+
func.func @main3(%a : tensor<3x4xf32>, %b : tensor<3x4xf32>) -> tensor<3x2x4xf32> {
1423
%u = stablehlo.reshape %a : (tensor<3x4xf32>) -> tensor<3x1x4xf32>
1524
%v = stablehlo.reshape %b : (tensor<3x4xf32>) -> tensor<3x1x4xf32>
1625
%concat = stablehlo.concatenate %u, %v, dim=1 : (tensor<3x1x4xf32>, tensor<3x1x4xf32>) -> tensor<3x2x4xf32>
@@ -24,3 +33,11 @@ module {
2433
// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<6x4xf32>) -> tensor<2x3x4xf32>
2534
// CHECK-NEXT: return %1 : tensor<2x3x4xf32>
2635
// CHECK-NEXT: }
36+
37+
// CHECK: func.func @main2(%arg0: tensor<3x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x3x4xf64> {
38+
// CHECK-NEXT: %0 = stablehlo.convert %arg0 : (tensor<3x4xf32>) -> tensor<3x4xf64>
39+
// CHECK-NEXT: %1 = stablehlo.convert %arg1 : (tensor<3x4xf32>) -> tensor<3x4xf64>
40+
// CHECK-NEXT: %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<3x4xf64>, tensor<3x4xf64>) -> tensor<6x4xf64>
41+
// CHECK-NEXT: %3 = stablehlo.reshape %2 : (tensor<6x4xf64>) -> tensor<2x3x4xf64>
42+
// CHECK-NEXT: return %3 : tensor<2x3x4xf64>
43+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)