
Commit 8096421

MoisesHerptrendx authored and committed

Embedding gradient performance optimization on GPU (#16355)

* Add Embedding backward Op for GPU
* Add some code documentation
* Use unnamed namespace for integer log2 function
* Fix lint issues
* Fix one more lint problem
* Remove unnecessary conditional ops
* Fix one more lint problem

1 parent 916fbf2 commit 8096421

1 file changed: src/operator/tensor/indexing_op.cu (+233 lines, -0 lines)

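The new backward path works in three steps: the indices are clipped and sorted, a binary search locates the contiguous span each vocabulary row occupies in the sorted order, and one thread block per row then accumulates the matching grad_out rows into grad_in. For orientation only, the CPU-side sketch below mirrors that scheme; the names are illustrative rather than taken from the commit, index values are assumed to be already clipped to [0, vocab_dim), and grad_in is assumed to be pre-initialized according to the write/add request.

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Reference sketch of the sort / find-bounds / accumulate scheme (illustrative only).
void EmbeddingBackwardReference(const std::vector<int>& index,       // token ids, length N
                                const std::vector<float>& grad_out,  // N x embedding_dim
                                std::vector<float>* grad_in,         // vocab_dim x embedding_dim
                                std::size_t vocab_dim,
                                std::size_t embedding_dim) {
  const std::size_t n = index.size();
  // Sort positions by token id so each row's occurrences become contiguous
  // (the GPU version keeps this permutation in original_index).
  std::vector<std::size_t> original_index(n);
  std::iota(original_index.begin(), original_index.end(), 0);
  std::sort(original_index.begin(), original_index.end(),
            [&](std::size_t a, std::size_t b) { return index[a] < index[b]; });
  std::size_t pos = 0;
  for (std::size_t row = 0; row < vocab_dim; ++row) {
    // Contiguous span [lo, pos) of this row inside the sorted order (EmbeddingFindBounds).
    const std::size_t lo = pos;
    while (pos < n && static_cast<std::size_t>(index[original_index[pos]]) == row) ++pos;
    // Accumulate every grad_out row that maps to this embedding row (EmbeddingGradKernel).
    for (std::size_t k = lo; k < pos; ++k) {
      const std::size_t src = original_index[k];
      for (std::size_t e = 0; e < embedding_dim; ++e) {
        (*grad_in)[row * embedding_dim + e] += grad_out[src * embedding_dim + e];
      }
    }
  }
}
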
@@ -545,6 +545,239 @@ void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
  });
}

namespace {
/*
 * \brief returns integer log2(a) rounded up
 */
inline int ilog2(unsigned int a) {
  int k = 1;
  while (a >>= 1) k++;
  return k;
}
}

/*
 * \brief finds the lower and upper-bound positions of each unique element within a sorted input array
 * \param sorted_data input elements previously sorted
 * \param bounds output containing all lower-bound followed by all upper-bound positions
 * \param data_dim total number of elements in the input array
 * \param vocab_dim maximum number of unique elements
 */
template <typename IType>
__global__ void EmbeddingFindBounds(const IType *sorted_data,
                                    IType *bounds,
                                    const index_t data_dim,
                                    const index_t vocab_dim) {
  const index_t id = blockIdx.x * blockDim.x + threadIdx.x;
  if (id >= vocab_dim) return;

  // Binary search to find lower bound: stored at bounds[0..vocab_dim-1]
  IType lower_bound = 0;
  IType upper_bound = data_dim - 1;
  IType mean;
  while (lower_bound < upper_bound) {
    mean = (lower_bound + upper_bound) / 2;
    if (id <= sorted_data[mean])
      upper_bound = mean;
    else
      lower_bound = mean + 1;
  }
  bool found_row = (sorted_data[lower_bound] == id);
  if (!found_row) {
    bounds[id] = -1;
    bounds[vocab_dim + id] = -2;
    return;
  } else {
    bounds[id] = lower_bound;
  }

  // Binary search to find upper bound: stored at bounds[vocab_dim..2*vocab_dim-1]
  lower_bound = 0;
  upper_bound = data_dim - 1;
  while (lower_bound < upper_bound) {
    mean = (lower_bound + upper_bound + 1) / 2;
    if (id >= sorted_data[mean])
      lower_bound = mean;
    else
      upper_bound = mean - 1;
  }
  bounds[vocab_dim + id] = upper_bound;
}

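// Illustrative example (added for exposition, not part of the commit): with
// sorted_data = {0, 0, 1, 3, 3, 3}, data_dim = 6 and vocab_dim = 4, the kernel writes
// bounds = {0, 2, -1, 3,  1, 2, -2, 5}: row 0 spans positions [0, 1], row 1 spans [2, 2],
// row 2 never occurs (the -1/-2 pair encodes a count of zero), and row 3 spans [3, 5].
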
/*
 * \brief kernel to compute gradient of EmbeddingOp
 * \param grad_in input gradient data
 * \param original_index reference to the position at original input data for each index
 * \param index_bounds lower and upper-bounds positions of each unique index
 * \param grad_out output gradient data
 * \param embbedding_dim dimension of the dense embedding
 * \param vocab_dim maximum number of unique indices in the data array: tokens vocabulary size
 * \param req write/add/null
 */
template <typename LType, typename DType, typename IType>
__global__ void EmbeddingGradKernel(DType *grad_in,
                                    const IType *original_index,
                                    const IType *index_bounds,
                                    const DType *grad_out,
                                    const index_t embbedding_dim,
                                    const index_t vocab_dim,
                                    const int req) {
  extern __shared__ int sharedmem[];
  LType* grad_in_row = reinterpret_cast<LType *>(sharedmem);

  // LType has to be bigger than DType, guarded in the launcher code
  const int n_val = sizeof(DType) < sizeof(LType) ? sizeof(LType) / sizeof(DType) : 1;
  const LType *aligned_grad_out = reinterpret_cast<const LType *>(grad_out);
  LType *aligned_grad_in = reinterpret_cast<LType *>(grad_in);
  const index_t aligned_emb_dim = embbedding_dim / n_val;
  DType *my_grad_in_row = reinterpret_cast<DType *>(&grad_in_row[threadIdx.x]);
  LType Lvalue[1];
  DType* Dvalues = reinterpret_cast<DType*>(Lvalue);

  IType my_row = blockIdx.x;
  if (my_row < vocab_dim) {
    // Read lower and upper bounds for current row
    IType lower_bound = index_bounds[my_row];
    IType upper_bound = index_bounds[vocab_dim + my_row];
    int nOccurrences = upper_bound - lower_bound + 1;

    for (index_t emb_id=threadIdx.x; emb_id < aligned_emb_dim; emb_id += blockDim.x) {
      // Initialize grad_in
      if (req == kAddTo) {
        grad_in_row[threadIdx.x] = aligned_grad_in[my_row * aligned_emb_dim + emb_id];
      } else {
        grad_in_row[threadIdx.x] = 0.0;
      }
      // Add all rows from grad_out according to indices in data
      for (index_t data_idx=lower_bound; data_idx < (lower_bound + nOccurrences); ++data_idx) {
        *Lvalue = aligned_grad_out[original_index[data_idx] * aligned_emb_dim + emb_id];
        for (index_t val_id = 0; val_id < n_val; val_id++) {
          my_grad_in_row[val_id] += Dvalues[val_id];
        }
      }
      // Save results
      aligned_grad_in[my_row * aligned_emb_dim + emb_id] = grad_in_row[threadIdx.x];
    }
  }
}

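// Note for exposition (not part of the commit): if DType is float and the launcher
// selects an 8-byte LType such as double, n_val == 2, so each loop iteration loads two
// consecutive floats of a grad_out row with one aligned access, accumulates them into
// the thread's shared-memory slot, and writes the result back with a single aligned store.
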
template<typename gpu, typename IType, typename DType>
void EmbeddingGradKernelCaller(const OpContext& ctx,
                               mshadow::Tensor<gpu, 2, DType> grad_in,
                               const mshadow::Tensor<gpu, 1, IType>& index,
                               const mshadow::Tensor<gpu, 2, DType> &grad_out,
                               const std::vector<OpReqType>& req) {
  using namespace mxnet_op;
  using namespace mshadow::expr;

  Stream<gpu> *s = ctx.get_stream<gpu>();
  const index_t data_dim = index.shape_[0];
  const index_t vocab_dim = grad_in.shape_[0];
  const index_t embbedding_dim = grad_in.shape_[1];

  // Calculate amount of temporary storage
  size_t sort_workspace_size = mxnet::op::SortByKeyWorkspaceSize<int, int, gpu>
    (data_dim);
  size_t workspace_size = 2 * data_dim * sizeof(int) +
    2 * vocab_dim * sizeof(int) + sort_workspace_size;

  // Request temporary storage
  Tensor<gpu, 1, char> workspace =
    ctx.requested[embedding::kTempSpace].get_space_typed<gpu, 1, char>(
      Shape1(workspace_size), s);

  // Create tensors
  size_t pos = 0;
  Tensor<gpu, 1, int> sorted_data(reinterpret_cast<int*>(&workspace[pos]),
                                  Shape1(data_dim), s);
  pos += data_dim * sizeof(int);
  // Reference to input data positions for each element of sorted_data
  Tensor<gpu, 1, int> original_index(reinterpret_cast<int*>(&workspace[pos]),
                                     Shape1(data_dim), s);
  pos += data_dim * sizeof(int);
  // lower and upper bound positions of each index within sorted_data
  Tensor<gpu, 1, int> bounds_index(reinterpret_cast<int*>(&workspace[pos]),
                                   Shape1(2 * vocab_dim), s);
  pos += 2 * vocab_dim * sizeof(int);
  Tensor<gpu, 1, char> Sort_temp_storage(&workspace[pos], Shape1(sort_workspace_size), s);

  // Clip indices [0, vocab_dim-1]
  Kernel<tcast_clip, gpu>::Launch(s, data_dim, sorted_data.dptr_, index.dptr_,
                                  static_cast<int>(vocab_dim));

  Kernel<range_fwd, gpu>::Launch(s, data_dim,
                                 1, 0, 1, kWriteTo, original_index.dptr_);

  // Sort indices array
  int num_bits = ilog2((vocab_dim - 1));
  mxnet::op::SortByKey(sorted_data, original_index, true, &Sort_temp_storage, 0, num_bits);

  // Find lower & upper bounds of each possible index
  const int threads_block_bounds = 128;
  const int nblocks_bounds = (vocab_dim + threads_block_bounds - 1) / threads_block_bounds;
  EmbeddingFindBounds<<<nblocks_bounds, threads_block_bounds, 0, Stream<gpu>::GetStream(s)>>>(
    sorted_data.dptr_, bounds_index.dptr_, data_dim, vocab_dim);

  // Compute Gradient
  int ltype = mxnet::common::cuda::get_load_type(embbedding_dim * sizeof(DType));
  MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
    int nelems_per_thread = sizeof(LType) / sizeof(DType);
    int threads_block_grad = 32;
    int maxThreads = 1024;
    while (threads_block_grad < (embbedding_dim/nelems_per_thread) &&
           (threads_block_grad < maxThreads))
      threads_block_grad += 32;
    size_t required_shared = threads_block_grad * sizeof(LType);
    dim3 blocks(vocab_dim, 1);
    EmbeddingGradKernel<LType><<<blocks, threads_block_grad, required_shared,
                                 Stream<gpu>::GetStream(s)>>>(
      grad_in.dptr_, original_index.dptr_,
      bounds_index.dptr_, grad_out.dptr_,
      embbedding_dim, vocab_dim,
      req[embedding::kWeight]);
  });
}

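// Illustrative numbers for the launch configuration above (an assumption, not taken
// from the commit): with embbedding_dim = 1024, DType = float and an 8-byte LType,
// nelems_per_thread == 2, threads_block_grad grows in steps of 32 up to 512 (= 1024 / 2),
// required_shared == 512 * 8 = 4096 bytes, and one block is launched per vocabulary row.
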
template<>
void EmbeddingOpBackward<gpu>(const nnvm::NodeAttrs& attrs,
                              const OpContext& ctx,
                              const std::vector<TBlob>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 2U);
  CHECK_EQ(req[embedding::kData], kNullOp)
    << "Embedding layer doesn't support calculating the data gradient";
  if (req[embedding::kWeight] == kNullOp) {
    return;
  }
  CHECK_EQ(outputs[1].type_flag_, inputs[0].type_flag_);

  const mxnet::TShape& ishape = inputs[1].shape_;
  const mxnet::TShape& oshape = inputs[0].shape_;

  Stream<gpu> *s = ctx.get_stream<gpu>();
  CHECK_NE(req[embedding::kWeight], kWriteInplace)
    << "Backward of Embedding does not support writing in place.";
  MSHADOW_TYPE_SWITCH(outputs[1].type_flag_, DType, {
    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {
      Tensor<gpu, 1, IType> data = inputs[1].get_with_shape<gpu, 1, IType>(
        Shape1(ishape.ProdShape(0, ishape.ndim())), s);
      Tensor<gpu, 2, DType> grad_out = inputs[0].get_with_shape<gpu, 2, DType>(
        Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
      Tensor<gpu, 2, DType> grad_in = outputs[1].get<gpu, 2, DType>(s);

      if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == kAddTo) {
        EmbeddingGradKernelCaller(ctx, grad_in, data, grad_out, req);
      } else {
        LOG(FATAL) << "wrong req";
      }
    });
  });
}

NNVM_REGISTER_OP(Embedding)
.set_attr<FCompute>("FCompute<gpu>", EmbeddingOpForward<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", SparseEmbeddingOpForwardEx<gpu>);
