@@ -9,8 +9,17 @@ module {
9
9
return %concat : tensor <2 x3 x4 xf32 >
10
10
}
11
11
12
+ func.func @main2 (%a : tensor <3 x4 xf32 >, %b : tensor <3 x4 xf32 >) -> tensor <2 x3 x4 xf64 > {
13
+ %u = stablehlo.reshape %a : (tensor <3 x4 xf32 >) -> tensor <1 x3 x4 xf32 >
14
+ %uc = stablehlo.convert %u : (tensor <1 x3 x4 xf32 >) -> tensor <1 x3 x4 xf64 >
15
+ %v = stablehlo.reshape %b : (tensor <3 x4 xf32 >) -> tensor <1 x3 x4 xf32 >
16
+ %vc = stablehlo.convert %v : (tensor <1 x3 x4 xf32 >) -> tensor <1 x3 x4 xf64 >
17
+ %concat = stablehlo.concatenate %uc , %vc , dim =0 : (tensor <1 x3 x4 xf64 >, tensor <1 x3 x4 xf64 >) -> tensor <2 x3 x4 xf64 >
18
+ return %concat : tensor <2 x3 x4 xf64 >
19
+ }
20
+
12
21
// TODO this opt
13
- func.func @main2 (%a : tensor <3 x4 xf32 >, %b : tensor <3 x4 xf32 >) -> tensor <3 x2 x4 xf32 > {
22
+ func.func @main3 (%a : tensor <3 x4 xf32 >, %b : tensor <3 x4 xf32 >) -> tensor <3 x2 x4 xf32 > {
14
23
%u = stablehlo.reshape %a : (tensor <3 x4 xf32 >) -> tensor <3 x1 x4 xf32 >
15
24
%v = stablehlo.reshape %b : (tensor <3 x4 xf32 >) -> tensor <3 x1 x4 xf32 >
16
25
%concat = stablehlo.concatenate %u , %v , dim =1 : (tensor <3 x1 x4 xf32 >, tensor <3 x1 x4 xf32 >) -> tensor <3 x2 x4 xf32 >
@@ -24,3 +33,11 @@ module {
24
33
// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<6x4xf32>) -> tensor<2x3x4xf32>
25
34
// CHECK-NEXT: return %1 : tensor<2x3x4xf32>
26
35
// CHECK-NEXT: }
36
+
37
+ // CHECK: func.func @main2(%arg0: tensor<3x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x3x4xf64> {
38
+ // CHECK-NEXT: %0 = stablehlo.convert %arg0 : (tensor<3x4xf32>) -> tensor<3x4xf64>
39
+ // CHECK-NEXT: %1 = stablehlo.convert %arg1 : (tensor<3x4xf32>) -> tensor<3x4xf64>
40
+ // CHECK-NEXT: %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<3x4xf64>, tensor<3x4xf64>) -> tensor<6x4xf64>
41
+ // CHECK-NEXT: %3 = stablehlo.reshape %2 : (tensor<6x4xf64>) -> tensor<2x3x4xf64>
42
+ // CHECK-NEXT: return %3 : tensor<2x3x4xf64>
43
+ // CHECK-NEXT: }
0 commit comments