This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[WIP] Transpose 2D GPU #16706

Closed · wants to merge 12 commits
51 changes: 8 additions & 43 deletions src/operator/tensor/matrix_op-inl.h
@@ -262,55 +262,18 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
}
};


/*!
* \brief This function performs a transpose of a 2D matrix, blocked to utilize the L1 cache
* \param in input tensor
* \param out output tensor
* \param row shape of dim 0 of input
* \param col shape of dim 1 of input
* \tparam DType data type
* \tparam is_addto whether to accumulate into the output (+=) instead of overwriting it
*/
template<typename DType, bool is_addto>
MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index_t col) {
// ensure cache-line hits and avoid cache misses for any configuration
// L1 cache size to be utilized = 32 KB = 2^15 bytes
// largest size of a single element of any dtype <= 8 bytes = 2^3
// number of elements that fit in L1 = 2^15 / 2^3 = 2^12
// block size = 2^6 x 2^6 (64 x 64)

// But we can also leverage loop unrolling (for parallelization):
// block size = 2^5 x 2^5 (32 x 32), leaving room for a 4x unrolled loop, since
// blocksize * blocksize * num_threads = cache_size / dtype_size
// instead of an explicit unroll, let the compiler pick the optimal unroll factor
const index_t blocksize = 32;

// collapse(2) parallelizes the two outer for loops across threads;
// the two inner loops stay serial so each thread walks a cache-resident block

// Microsoft Visual C++ compiler does not support omp collapse
#ifdef _MSC_VER
#pragma omp parallel for
#else
#pragma omp parallel for collapse(2)
#endif // _MSC_VER

for (index_t i = 0; i < row; i += blocksize) {
for (index_t j = 0; j < col; j += blocksize) {
// transpose the block
for (index_t a = j; (a < blocksize + j) && (a < col); ++a) {
for (index_t b = i; (b < blocksize + i) && (b < row); ++b) {
if (!is_addto) {
out[a * row + b] = in[b * col + a];
} else {
out[a * row + b] += in[b * col + a];
}
}
}
}
}
}
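As a quick sanity check on the blocking arithmetic in the comments above, a small illustrative snippet; the constants merely restate the comments and are not part of this PR:

```cpp
// Illustrative only: verifies the 32 KB / 8 B => 64x64 (or four 32x32) tiling math.
constexpr int kL1Bytes = 32 * 1024;                // 32 KB = 2^15
constexpr int kMaxDTypeBytes = 8;                  // largest dtype element = 2^3
constexpr int kElems = kL1Bytes / kMaxDTypeBytes;  // 2^12 = 4096 elements
static_assert(kElems == 64 * 64, "one 64x64 block fills the L1 budget");
static_assert(kElems == 4 * 32 * 32, "or four 32x32 blocks, leaving room for a 4x unroll");
```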


template<typename DType, typename xpu>
inline typename std::enable_if<std::is_same<xpu, gpu>::value, void>::type
Transpose2D(const DType *in, DType *out, index_t row, index_t col);

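The declaration above uses `std::enable_if` so the GPU overload of `Transpose2D` only participates in overload resolution when `xpu` is `gpu`; the definition lives in transpose_op-inl.cuh, while the CPU version moves to matrix_op.cc. For readers unfamiliar with this dispatch pattern, a minimal self-contained sketch — the `cpu`/`gpu` tag structs and `Launch` are illustrative stand-ins, not part of this PR:

```cpp
#include <iostream>
#include <type_traits>

struct cpu {};  // illustrative stand-ins for mshadow's device tags
struct gpu {};

// Each overload participates in overload resolution only when its
// enable_if condition holds for the chosen device tag.
template <typename xpu>
typename std::enable_if<std::is_same<xpu, cpu>::value, void>::type
Launch() { std::cout << "cpu path\n"; }

template <typename xpu>
typename std::enable_if<std::is_same<xpu, gpu>::value, void>::type
Launch() { std::cout << "gpu path\n"; }

int main() {
  Launch<cpu>();  // selects the cpu overload
  Launch<gpu>();  // selects the gpu overload
  return 0;
}
```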
inline bool IsIdentityTranspose(const TShape& axes) {
for (dim_t i = 0; i < axes.ndim(); i++) {
@@ -360,10 +323,12 @@ void TransposeImpl(RunContext ctx,
MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
switch (axes.ndim()) {
case 2: {
Tensor<xpu, 2, DType> in = src.get<xpu, 2, DType>(s);
Tensor<xpu, 2, DType> out = ret.get<xpu, 2, DType>(s);
if (ctx.get_ctx().dev_mask() == cpu::kDevMask) {
Transpose2D<DType, is_addto>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]);
} else {
LOG(FATAL) << "Not implemented. We should never reach here because the 2D case "
"on GPU is handled by transpose_pseudo2D."
43 changes: 43 additions & 0 deletions src/operator/tensor/matrix_op.cc
@@ -306,6 +306,49 @@ inline static bool TransposeStorageType(const nnvm::NodeAttrs& attrs,
}
#endif

/*!
* \brief This function performs a transpose of a 2D matrix, blocked to utilize the L1 cache
* \param in input tensor
* \param out output tensor
* \param row shape of dim 0 of input
* \param col shape of dim 1 of input
* \tparam DType data type
*/
template<typename DType>
inline void Transpose2D(const DType *in, DType *out, index_t row, index_t col) {
// ensure cache-line hits and avoid cache misses for any configuration
// L1 cache size to be utilized = 32 KB = 2^15 bytes
// largest size of a single element of any dtype <= 8 bytes = 2^3
// number of elements that fit in L1 = 2^15 / 2^3 = 2^12
// block size = 2^6 x 2^6 (64 x 64)

// But we can also leverage loop unrolling (for parallelization):
// block size = 2^5 x 2^5 (32 x 32), leaving room for a 4x unrolled loop, since
// blocksize * blocksize * num_threads = cache_size / dtype_size
// instead of an explicit unroll, let the compiler pick the optimal unroll factor
const index_t blocksize = 32;

// collapse(2) parallelizes the two outer for loops across threads;
// the two inner loops stay serial so each thread walks a cache-resident block

// Microsoft Visual C++ compiler does not support omp collapse
#ifdef _MSC_VER
#pragma omp parallel for
#else
#pragma omp parallel for collapse(2)
#endif // _MSC_VER

for (index_t i = 0; i < row; i += blocksize) {
for (index_t j = 0; j < col; j += blocksize) {
// transpose the block
for (index_t a = j; (a < blocksize + j) && (a < col); ++a) {
for (index_t b = i; (b < blocksize + i) && (b < row); ++b) {
out[a * row + b] = in[b * col + a];
}
}
}
}
}
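As a hedged usage sketch — assuming the `Transpose2D` template above is visible in the translation unit and `index_t` is mshadow's signed index type — a tiny host-side check of the row-major transpose:

```cpp
#include <iostream>
#include <vector>

int main() {
  const index_t row = 2, col = 3;
  std::vector<float> in = {1, 2, 3,
                           4, 5, 6};        // 2x3 input, row-major
  std::vector<float> out(row * col, 0.0f);  // 3x2 output
  Transpose2D(in.data(), out.data(), row, col);
  for (index_t a = 0; a < col; ++a) {       // expected: 1 4 / 2 5 / 3 6
    for (index_t b = 0; b < row; ++b) std::cout << out[a * row + b] << ' ';
    std::cout << '\n';
  }
  return 0;
}
```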

NNVM_REGISTER_OP(transpose)
.describe(R"code(Permutes the dimensions of an array.
Examples::
Expand Down
1 change: 0 additions & 1 deletion src/operator/tensor/matrix_op.cu
@@ -26,7 +26,6 @@
#include "./matrix_op-inl.h"
#include "./elemwise_unary_op.h"


namespace mxnet {
namespace op {

78 changes: 78 additions & 0 deletions src/operator/tensor/transpose_op-inl.cuh
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file transpose_op-inl.cuh
* \brief Function definition for cuda transpose
* \author Chaitanya Bapat
*/

#ifndef MXNET_OPERATOR_TENSOR_TRANSPOSE_OP_INL_CUH_
#define MXNET_OPERATOR_TENSOR_TRANSPOSE_OP_INL_CUH_

#include <mxnet/tuple.h>
#include <mxnet/tensor_blob.h>
#include <mshadow/base.h>
#include <algorithm>
#include <utility>
#include "../../common/cuda_utils.h"

namespace mxnet {
namespace op {
namespace mshadow {
namespace cuda {
template<typename DType>
__global__ void Transpose2DKernel(const DType *in, DType *out, index_t row, index_t col) {
  const index_t TILE_DIM = 32;
  const index_t BLOCK_ROWS = 8;
  // Pad the tile by one column so the column-wise read back from shared
  // memory does not serialize on a single bank.
  __shared__ DType tile[TILE_DIM][TILE_DIM + 1];

  index_t x = blockIdx.x * TILE_DIM + threadIdx.x;
  index_t y = blockIdx.y * TILE_DIM + threadIdx.y;

  // Each 32x8 thread block cooperatively loads one 32x32 tile;
  // guard against partial tiles at the matrix edges.
  for (index_t j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
    if (x < col && (y + j) < row)
      tile[threadIdx.y + j][threadIdx.x] = in[(y + j) * col + x];
  }

  __syncthreads();

  x = blockIdx.y * TILE_DIM + threadIdx.x;  // transposed block offset
  y = blockIdx.x * TILE_DIM + threadIdx.y;

  for (index_t j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
    if (x < row && (y + j) < col)
      out[(y + j) * row + x] = tile[threadIdx.x][threadIdx.y + j];
  }
}
} // namespace cuda
} // namespace mshadow
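A note on the `TILE_DIM + 1` padding in the kernel above: shared memory is divided into 32 banks, and the transposed read walks a tile column; padding the row stride to 33 spreads that column across all banks instead of serializing on one. A small illustrative check (bank math assumes 4-byte words; 8-byte dtypes occupy two banks per element):

```cpp
#include <cstdio>

int main() {
  const int kBanks = 32, kStride = 33;  // kStride = TILE_DIM + 1
  for (int r = 0; r < 4; ++r)           // first rows of one tile column
    std::printf("row %d -> bank %d\n", r, (r * kStride) % kBanks);
  // stride 33 yields banks 0,1,2,3,... while stride 32 would yield bank 0 every time
  return 0;
}
```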

template<typename DType, typename xpu>
inline typename std::enable_if<std::is_same<xpu, gpu>::value, void>::type
Transpose2D(const DType *in, DType *out, index_t row, index_t col) {
  using namespace mshadow::cuda;
  // One thread block per 32x32 tile; each block uses 32x8 threads,
  // matching TILE_DIM and BLOCK_ROWS in the kernel.
  const index_t kTileDim = 32;
  const index_t kBlockRows = 8;
  dim3 grid((col + kTileDim - 1) / kTileDim, (row + kTileDim - 1) / kTileDim);
  dim3 block(kTileDim, kBlockRows);
  Transpose2DKernel<DType><<<grid, block>>>(in, out, row, col);
  MSHADOW_CUDA_POST_KERNEL_CHECK(Transpose2DKernel);
}

} // namespace op
} // namespace mxnet
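For example, a 100x60 input yields grid = (ceil(60/32), ceil(100/32)) = (2, 4), i.e. 8 blocks of 32x8 = 256 threads. A hedged host-side driver sketch, assuming the `Transpose2D<DType, xpu>` launcher above plus mshadow's `gpu` tag and `index_t` are in scope (error handling elided; not part of this PR):

```cpp
#include <cuda_runtime.h>

// Illustrative wrapper: copies host data to the device, runs the tiled
// transpose, and copies the result back.
void TransposeOnDevice(const float *h_in, float *h_out, index_t row, index_t col) {
  float *d_in = nullptr, *d_out = nullptr;
  const size_t bytes = row * col * sizeof(float);
  cudaMalloc(&d_in, bytes);
  cudaMalloc(&d_out, bytes);
  cudaMemcpy(d_in, h_in, bytes, cudaMemcpyHostToDevice);
  Transpose2D<float, gpu>(d_in, d_out, row, col);  // tiled kernel launch above
  cudaMemcpy(h_out, d_out, bytes, cudaMemcpyDeviceToHost);
  cudaFree(d_in);
  cudaFree(d_out);
}
```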


#endif // MXNET_OPERATOR_TENSOR_TRANSPOSE_OP_INL_CUH_