@@ -257,6 +257,51 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
}
};
+
+/*!
+ * \brief This function performs a transpose operation on a 2D matrix by utilizing the L1 cache
+ * \param in input tensor
+ * \param out output tensor
+ * \param row shape of dim 0 of input
+ * \param col shape of dim 1 of input
+ */
+template <typename DType>
+MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index_t col) {
+  // ensure cache line hits and prevent cache misses for any matrix configuration
+  // L1 cache size to be utilized = 32 KB = 2^15 bytes
+  // largest size of a single element of any dtype <= 8 bytes = 2^3
+  // number of elements that fit = 2^15 / 2^3 = 2^12
+  // block size = 2^6 x 2^6 (64 x 64)
+
+  // But we can also leverage loop unrolling (for parallelization):
+  // block size = 2^5 x 2^5 (32 x 32), leaving room for a potential 4-way unroll, so that
+  // blocksize * blocksize * num_threads = cache_size / dtype_size
+  // Instead of an explicit unroll, let the compiler pick the optimal unroll factor
+  index_t blocksize = 32;
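+  // check: 32 * 32 elements per block * 4 threads * 8 bytes = 32 KB, the assumed L1 budget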
+
+  // collapse(2) parallelizes the two outer for loops;
+  // the inner two loops are left sequential to preserve cache locality
+
+  // the Microsoft Visual C++ compiler does not support omp collapse
+#ifdef _MSC_VER
+  #pragma omp parallel for
+#else
+  #pragma omp parallel for collapse(2)
+#endif  // _MSC_VER
+
+  for (index_t i = 0; i < row; i += blocksize) {
+    for (index_t j = 0; j < col; j += blocksize) {
+      // transpose the block
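+      // inner loop over b gives unit-stride writes to out; reads from in
+      // stride by col but stay within the current cached block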
+      for (index_t a = j; (a < blocksize + j) && (a < col); ++a) {
+        for (index_t b = i; (b < blocksize + i) && (b < row); ++b) {
+          out[a * row + b] = in[b * col + a];
+        }
+      }
+    }
+  }
+}
+
+
template <typename xpu>
void TransposeImpl(RunContext ctx,
                   const TBlob& src,
@@ -285,8 +330,13 @@ void TransposeImpl(RunContext ctx,
  case 2: {
    mshadow::Tensor<xpu, 2, DType> in = src.FlatTo2D<xpu, DType>(s);
    mshadow::Tensor<xpu, 2, DType> out = ret.FlatTo2D<xpu, DType>(s);
+
    if (axes[0] == 1 && axes[1] == 0) {
-      out = in.T();
+      if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
+        Transpose2D<DType>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
+      } else {
+        out = in.T();
+      }
    } else {
      Copy(out, in, s);
    }
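
For context, the new kernel can be sanity-checked outside MXNet with a small standalone harness. The sketch below is not part of the patch: it assumes a local index_t typedef and a plain copy of the loop nest (no MSHADOW_XINLINE, no OpenMP pragmas), and compares the blocked result against a naive transpose on a shape that is deliberately not a multiple of the block size.

// Standalone sketch; assumptions: local index_t, no mshadow/OpenMP.
#include <cstdint>
#include <iostream>
#include <vector>

using index_t = std::int64_t;

template <typename DType>
void Transpose2D(const DType *in, DType *out, index_t row, index_t col) {
  const index_t blocksize = 32;
  for (index_t i = 0; i < row; i += blocksize) {
    for (index_t j = 0; j < col; j += blocksize) {
      for (index_t a = j; (a < blocksize + j) && (a < col); ++a) {
        for (index_t b = i; (b < blocksize + i) && (b < row); ++b) {
          out[a * row + b] = in[b * col + a];
        }
      }
    }
  }
}

int main() {
  const index_t row = 100, col = 37;  // deliberately not multiples of 32
  std::vector<float> in(row * col), out(row * col), ref(row * col);
  for (index_t k = 0; k < row * col; ++k) in[k] = static_cast<float>(k);

  Transpose2D(in.data(), out.data(), row, col);

  // naive reference transpose: ref holds the col x row result
  for (index_t b = 0; b < row; ++b)
    for (index_t a = 0; a < col; ++a)
      ref[a * row + b] = in[b * col + a];

  std::cout << (out == ref ? "match" : "mismatch") << std::endl;
  return 0;
}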