@@ -85,32 +85,30 @@ module {
85
85
// A kernel that computes a BSR sampled dense matrix matrix multiplication
86
86
// using a "spy" function and in-place update of the sampling sparse matrix.
87
87
//
88
- // TODO: re-enable the following test.
89
- //
90
- // func.func @SDDMM_block(%args: tensor<?x?xf32, #BSR>,
91
- // %arga: tensor<?x?xf32>,
92
- // %argb: tensor<?x?xf32>) -> tensor<?x?xf32, #BSR> {
93
- // %result = linalg.generic #trait_SDDMM
94
- // ins(%arga, %argb: tensor<?x?xf32>, tensor<?x?xf32>)
95
- // outs(%args: tensor<?x?xf32, #BSR>) {
96
- // ^bb(%a: f32, %b: f32, %s: f32):
97
- // %f0 = arith.constant 0.0 : f32
98
- // %u = sparse_tensor.unary %s : f32 to f32
99
- // present={
100
- // ^bb0(%p: f32):
101
- // %mul = arith.mulf %a, %b : f32
102
- // sparse_tensor.yield %mul : f32
103
- // }
104
- // absent={}
105
- // %r = sparse_tensor.reduce %s, %u, %f0 : f32 {
106
- // ^bb0(%p: f32, %q: f32):
107
- // %add = arith.addf %p, %q : f32
108
- // sparse_tensor.yield %add : f32
109
- // }
110
- // linalg.yield %r : f32
111
- // } -> tensor<?x?xf32, #BSR>
112
- // return %result : tensor<?x?xf32, #BSR>
113
- // }
88
+ func.func @SDDMM_block (%args: tensor <?x?xf32 , #BSR >, // %args: BSR-encoded sampling matrix, also used as the output operand below
89
+ %arga: tensor <?x?xf32 >,
90
+ %argb: tensor <?x?xf32 >) -> tensor <?x?xf32 , #BSR > {
91
+ %result = linalg.generic #trait_SDDMM
92
+ ins (%arga , %argb: tensor <?x?xf32 >, tensor <?x?xf32 >) // the two dense operands
93
+ outs (%args: tensor <?x?xf32 , #BSR >) { // the sampling matrix is passed as outs, so it is the tensor being updated
94
+ ^bb (%a: f32 , %b: f32 , %s: f32 ): // %a, %b: elements of the dense inputs; %s: element of the sampling matrix
95
+ %f0 = arith.constant 0.0 : f32
96
+ %u = sparse_tensor.unary %s : f32 to f32 // the "spy": only %s's presence selects a branch
97
+ present ={
98
+ ^bb0 (%p: f32 ): // %p (the value of %s) is intentionally unused
99
+ %mul = arith.mulf %a , %b : f32 // the product depends only on the dense operands
100
+ sparse_tensor.yield %mul : f32
101
+ }
102
+ absent ={} // empty region: nothing is yielded when %s is absent
103
+ %r = sparse_tensor.reduce %s , %u , %f0 : f32 { // adds %u onto %s; %f0 (0.0) is the third operand (presumably the reduction identity; confirm against the dialect docs)
104
+ ^bb0 (%p: f32 , %q: f32 ):
105
+ %add = arith.addf %p , %q : f32
106
+ sparse_tensor.yield %add : f32
107
+ }
108
+ linalg.yield %r : f32
109
+ } -> tensor <?x?xf32 , #BSR >
110
+ return %result : tensor <?x?xf32 , #BSR >
111
+ }
114
112
115
113
func.func private @getTensorFilename (index ) -> (!Filename ) // body-less private declaration; the test calls it with an index and passes the resulting !Filename to sparse_tensor.new
116
114
@@ -153,15 +151,15 @@ module {
153
151
//
154
152
%fileName = call @getTensorFilename (%c0 ) : (index ) -> (!Filename )
155
153
%m_csr = sparse_tensor.new %fileName : !Filename to tensor <?x?xf32 , #CSR >
156
- // %m_bsr = sparse_tensor.new %fileName : !Filename to tensor<?x?xf32, #BSR>
154
+ %m_bsr = sparse_tensor.new %fileName : !Filename to tensor <?x?xf32 , #BSR >
157
155
158
156
// Call the kernel.
159
157
%0 = call @SDDMM (%m_csr , %a , %b )
160
158
: (tensor <?x?xf32 , #CSR >,
161
159
tensor <?x?xf32 >, tensor <?x?xf32 >) -> tensor <?x?xf32 , #CSR >
162
- // %1 = call @SDDMM_block(%m_bsr, %a, %b)
163
- // : (tensor<?x?xf32, #BSR>,
164
- // tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32, #BSR>
160
+ %1 = call @SDDMM_block (%m_bsr , %a , %b )
161
+ : (tensor <?x?xf32 , #BSR >,
162
+ tensor <?x?xf32 >, tensor <?x?xf32 >) -> tensor <?x?xf32 , #BSR >
165
163
166
164
//
167
165
// Print the result for verification. Note that the "spy" determines what
@@ -170,18 +168,18 @@ module {
170
168
// in the original zero positions).
171
169
//
172
170
// CHECK: ( 5, 10, 24, 19, 53, 42, 55, 56 )
173
- // C_HECK -NEXT: ( 5, 10, 8, 19, 24, 24, 40, 53, 42, 55, 56, 64 )
171
+ // CHECK-NEXT: ( 5, 10, 8, 19, 24, 24, 40, 53, 42, 55, 56, 64 )
174
172
//
175
173
%v0 = sparse_tensor.values %0 : tensor <?x?xf32 , #CSR > to memref <?xf32 >
176
174
%vv0 = vector.transfer_read %v0 [%c0 ], %d0 : memref <?xf32 >, vector <8 xf32 >
177
175
vector.print %vv0 : vector <8 xf32 >
178
- // %v1 = sparse_tensor.values %1 : tensor<?x?xf32, #BSR> to memref<?xf32>
179
- // %vv1 = vector.transfer_read %v1[%c0], %d0 : memref<?xf32>, vector<12xf32>
180
- // vector.print %vv1 : vector<12xf32>
176
+ %v1 = sparse_tensor.values %1 : tensor <?x?xf32 , #BSR > to memref <?xf32 >
177
+ %vv1 = vector.transfer_read %v1 [%c0 ], %d0 : memref <?xf32 >, vector <12 xf32 >
178
+ vector.print %vv1 : vector <12 xf32 >
181
179
182
180
// Release the resources.
183
181
bufferization.dealloc_tensor %0 : tensor <?x?xf32 , #CSR >
184
- // bufferization.dealloc_tensor %1 : tensor<?x?xf32, #BSR>
182
+ bufferization.dealloc_tensor %1 : tensor <?x?xf32 , #BSR >
185
183
186
184
llvm.call @mgpuDestroySparseEnv () : () -> ()
187
185
return
0 commit comments