
Commit 2d2938a

ChaiBapchya authored and aaronmarkham committed
Faster Transpose 2D (apache#16104)
* 2d transpose naive
* omp pragma
* omp pragma unroll
* blocksize
* make it 2d tile
* loop peeling
* better loop peeling
* redundancy
* removed bool
* removing excess for loops, memory save
* fix internal forloop
* remove commented code, lint fix
* Trigger notification
* explain params, indent fix, explain blocksize
* fix p,n and reduce for loop computation j+a,i+b
* kernel
* gpu thread 1
* remove gpu implementation
* fix internal for loop
* unittest to catch the previous error
* optimizations
* microsoft cpp doesn't support omp collapse
1 parent d1897a6 · commit 2d2938a

File tree

2 files changed: +58 -1 lines changed


src/operator/tensor/matrix_op-inl.h

Lines changed: 51 additions & 1 deletion
@@ -257,6 +257,51 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
   }
 };
 
+
+/*!
+ * \brief This function performs transpose operation on a 2D matrix by utilizing the L1 cache
+ * \param in input tensor
+ * \param out output tensor
+ * \param row shape of dim 0 of input
+ * \param col shape of dim 1 of input
+ */
+template<typename DType>
+MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index_t col) {
+  // ensure cache line hits and prevent cache miss for any configuration
+  // L1 cache size to be utilized = 32kb = 2^15
+  // Largest size of a single unit of any dtype <= 8 byte = 2^3
+  // Number of elements - (2^15/2^3) = 2^12
+  // Block-size - 2^6 v 2^6 (64 v 64)
+
+  // But we could leverage unrolling of for loops (for parallelization)
+  // Block-size - 2^5 v 2^5 (32 v 32) with potential 4 pragma for loop unrolled
+  // blocksize * blocksize * num_threads = cache_size / dtype_size
+  // Instead of explicit unroll, let compiler figure out optimal unroll factor
+  index_t blocksize = 32;
+
+  // collapse 2 parallelizes 2 for loops
+  // inner 2 for loops aren't parallelized to prevent cache miss
+
+  // Microsoft Visual C++ compiler does not support omp collapse
+#ifdef _MSC_VER
+  #pragma omp parallel for
+#else
+  #pragma omp parallel for collapse(2)
+#endif  // _MSC_VER
+
+  for (index_t i = 0; i < row; i += blocksize) {
+    for (index_t j = 0; j < col; j += blocksize) {
+      // transpose the block
+      for (index_t a = j; (a < blocksize + j) && (a < col); ++a) {
+        for (index_t b = i; (b < blocksize + i) && (b < row); ++b) {
+          out[a * row + b] = in[b * col + a];
+        }
+      }
+    }
+  }
+}
+
+
 template<typename xpu>
 void TransposeImpl(RunContext ctx,
                    const TBlob& src,
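The arithmetic behind the block size, spelled out: a 32KB L1 cache holds 2^15 / 2^3 = 4096 eight-byte elements, so a single 64x64 tile would fill it exactly. The commit halves the side to 32 so that, per the comment's budget blocksize * blocksize * num_threads = cache_size / dtype_size, four unrolled or threaded 32x32 tiles still fit (32 * 32 * 4 = 4096). Below is a minimal standalone sketch of the same tiling idea in plain C++/OpenMP. It is not the MXNet source; names such as transpose2d_blocked and kBlock are illustrative. It verifies correctness on a shape that is not a multiple of the block size, the same situation the new 50x51 unit test exercises:

// Minimal standalone sketch of the tiling idea above (illustrative names,
// not MXNet code). Compile with e.g.: g++ -O2 -fopenmp sketch.cc
#include <cstdio>
#include <vector>

// Assumed cache budget, mirroring the kernel's comments:
// 32KB L1 / 8-byte elements = 4096 elements; 32 * 32 * 4 threads = 4096.
constexpr int kBlock = 32;

template <typename DType>
void transpose2d_blocked(const DType* in, DType* out, int row, int col) {
#ifdef _MSC_VER
  #pragma omp parallel for
#else
  #pragma omp parallel for collapse(2)
#endif
  for (int i = 0; i < row; i += kBlock) {
    for (int j = 0; j < col; j += kBlock) {
      // The two bounds guards handle partial tiles when row/col are not
      // multiples of kBlock, the case the new 50x51 unit test covers.
      for (int a = j; a < j + kBlock && a < col; ++a) {
        for (int b = i; b < i + kBlock && b < row; ++b) {
          out[a * row + b] = in[b * col + a];
        }
      }
    }
  }
}

int main() {
  const int row = 50, col = 51;  // deliberately not multiples of kBlock
  std::vector<double> in(row * col), out(row * col);
  for (int k = 0; k < row * col; ++k) in[k] = k;
  transpose2d_blocked(in.data(), out.data(), row, col);

  // Check out[a][b] == in[b][a] for every element.
  for (int b = 0; b < row; ++b) {
    for (int a = 0; a < col; ++a) {
      if (out[a * row + b] != in[b * col + a]) {
        std::printf("mismatch at (%d, %d)\n", b, a);
        return 1;
      }
    }
  }
  std::printf("transpose of %dx%d verified\n", row, col);
  return 0;
}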
@@ -285,8 +330,13 @@ void TransposeImpl(RunContext ctx,
     case 2: {
       mshadow::Tensor<xpu, 2, DType> in = src.FlatTo2D<xpu, DType>(s);
       mshadow::Tensor<xpu, 2, DType> out = ret.FlatTo2D<xpu, DType>(s);
+
       if (axes[0] == 1 && axes[1] == 0) {
-        out = in.T();
+        if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
+          Transpose2D<DType>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
+        } else {
+          out = in.T();
+        }
       } else {
         Copy(out, in, s);
       }
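With the kernel in place, TransposeImpl dispatches on device: the 2D CPU case now calls Transpose2D, while the GPU path keeps the mshadow expression out = in.T(). Tiling pays off on CPU because a plain two-loop transpose walks one side of the copy at a large stride, so nearly every access on that side misses cache once the matrix outgrows L1, whereas a tile keeps both the read and write footprints cache-resident. A rough micro-benchmark sketch follows (single-threaded, illustrative names, timings vary by machine and compiler; this is not the commit's benchmark):

// Rough micro-benchmark sketch (illustrative, not the commit's benchmark).
// Compile with e.g.: g++ -O2 bench.cc
#include <chrono>
#include <cstdio>
#include <vector>

// Textbook transpose: the strided side of the copy misses cache on
// almost every access once the matrix is much larger than L1.
void transpose_naive(const double* in, double* out, int row, int col) {
  for (int b = 0; b < row; ++b)
    for (int a = 0; a < col; ++a)
      out[a * row + b] = in[b * col + a];
}

// The commit's tiling scheme, single-threaded here to isolate the
// cache effect from any OpenMP speedup.
void transpose_blocked(const double* in, double* out, int row, int col) {
  const int bs = 32;
  for (int i = 0; i < row; i += bs)
    for (int j = 0; j < col; j += bs)
      for (int a = j; a < j + bs && a < col; ++a)
        for (int b = i; b < i + bs && b < row; ++b)
          out[a * row + b] = in[b * col + a];
}

template <typename F>
double ms(F f) {
  auto t0 = std::chrono::steady_clock::now();
  f();
  auto t1 = std::chrono::steady_clock::now();
  return std::chrono::duration<double, std::milli>(t1 - t0).count();
}

int main() {
  const int row = 2048, col = 2048;  // 32MB per buffer, far beyond L1
  std::vector<double> in(static_cast<size_t>(row) * col, 1.0);
  std::vector<double> out(in.size());
  // One untimed pass each to warm the buffers, then time a pass.
  transpose_naive(in.data(), out.data(), row, col);
  transpose_blocked(in.data(), out.data(), row, col);
  std::printf("naive:   %.1f ms\n",
              ms([&] { transpose_naive(in.data(), out.data(), row, col); }));
  std::printf("blocked: %.1f ms\n",
              ms([&] { transpose_blocked(in.data(), out.data(), row, col); }));
  return 0;
}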

tests/python/unittest/test_operator.py

Lines changed: 7 additions & 0 deletions
@@ -2874,6 +2874,13 @@ def test_transpose():
             assert_allclose(np.transpose(x.asnumpy()), y.asnumpy())
 
 
+@with_seed()
+def test_larger_transpose():
+    x = mx.nd.random.normal(shape=(50,51))
+    y = mx.nd.transpose(x)
+    assert_allclose(np.transpose(x.asnumpy()), y.asnumpy())
+
+
 @with_seed()
 def test_expand_dims():
     for ndim in range(1, 6):
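A note on the shape choice: 50 and 51 are deliberately not multiples of the 32-element block size, so this test forces Transpose2D through its partial-tile bounds checks ((a < blocksize + j) && (a < col) and the matching row guard). Per the commit message, this is the "unittest to catch the previous error", the internal for-loop bug that the smaller shapes in the existing test_transpose apparently did not expose.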

0 commit comments
