@@ -545,6 +545,239 @@ void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
  });
}

+ namespace {
+ /*
+  * \brief returns integer log2(a) rounded up
+  */
+ inline int ilog2(unsigned int a) {
+   int k = 1;
+   while (a >>= 1) k++;
+   return k;
+ }
+ }  // namespace
+
+ /*
+  * \brief finds the lower and upper-bound positions of each unique element within a sorted input array
+  * \param sorted_data input elements previously sorted
+  * \param bounds output containing all lower-bound followed by all upper-bound positions
+  * \param data_dim total number of elements in the input array
+  * \param vocab_dim maximum number of unique elements
+  */
+ template <typename IType>
+ __global__ void EmbeddingFindBounds(const IType *sorted_data,
+                                     IType *bounds,
+                                     const index_t data_dim,
+                                     const index_t vocab_dim) {
+   const index_t id = blockIdx.x * blockDim.x + threadIdx.x;
+   if (id >= vocab_dim) return;
+
+   // Binary search to find lower bound: stored at bounds[0..vocab_dim-1]
+   IType lower_bound = 0;
+   IType upper_bound = data_dim - 1;
+   IType mean;
+   while (lower_bound < upper_bound) {
+     mean = (lower_bound + upper_bound) / 2;
+     if (id <= sorted_data[mean])
+       upper_bound = mean;
+     else
+       lower_bound = mean + 1;
+   }
+   bool found_row = (sorted_data[lower_bound] == id);
+   if (!found_row) {
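+     // Index never present in the data: store an empty range (upper < lower)
+     // so the gradient kernel sees zero occurrences for this row.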
+     bounds[id] = -1;
+     bounds[vocab_dim + id] = -2;
+     return;
+   } else {
+     bounds[id] = lower_bound;
+   }
+
+   // Binary search to find upper bound: stored at bounds[vocab_dim..2*vocab_dim-1]
+   lower_bound = 0;
+   upper_bound = data_dim - 1;
+   while (lower_bound < upper_bound) {
+     mean = (lower_bound + upper_bound + 1) / 2;
+     if (id >= sorted_data[mean])
+       lower_bound = mean;
+     else
+       upper_bound = mean - 1;
+   }
+   bounds[vocab_dim + id] = upper_bound;
+ }
+
+ /*
+  * \brief kernel to compute gradient of EmbeddingOp
+  * \param grad_in input gradient data (gradient w.r.t. the embedding weight, written by this kernel)
+  * \param original_index position in the original input data for each sorted index
+  * \param index_bounds lower and upper-bound positions of each unique index
+  * \param grad_out output gradient data (gradient w.r.t. the op output, read by this kernel)
+  * \param embbedding_dim dimension of the dense embedding
+  * \param vocab_dim maximum number of unique indices in the data array: tokens vocabulary size
+  * \param req write/add/null
+  */
+ template <typename LType, typename DType, typename IType>
+ __global__ void EmbeddingGradKernel(DType *grad_in,
+                                     const IType *original_index,
+                                     const IType *index_bounds,
+                                     const DType *grad_out,
+                                     const index_t embbedding_dim,
+                                     const index_t vocab_dim,
+                                     const int req) {
+   extern __shared__ int sharedmem[];
+   LType* grad_in_row = reinterpret_cast<LType *>(sharedmem);
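+   // grad_in_row stages one vectorized (LType) element of the current output row per thread in shared memory.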
+
+   // LType has to be bigger than DType, guarded in the launcher code
+   const int n_val = sizeof(DType) < sizeof(LType) ? sizeof(LType) / sizeof(DType) : 1;
+   const LType *aligned_grad_out = reinterpret_cast<const LType *>(grad_out);
+   LType *aligned_grad_in = reinterpret_cast<LType *>(grad_in);
+   const index_t aligned_emb_dim = embbedding_dim / n_val;
+   DType *my_grad_in_row = reinterpret_cast<DType *>(&grad_in_row[threadIdx.x]);
+   LType Lvalue[1];
+   DType* Dvalues = reinterpret_cast<DType*>(Lvalue);
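+   // Lvalue receives one vector-width load from grad_out; Dvalues exposes its n_val packed DType lanes for accumulation.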
+
+   IType my_row = blockIdx.x;
+   if (my_row < vocab_dim) {
+     // Read lower and upper bounds for current row
+     IType lower_bound = index_bounds[my_row];
+     IType upper_bound = index_bounds[vocab_dim + my_row];
+     int nOccurrences = upper_bound - lower_bound + 1;
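+     // nOccurrences is 0 for indices that never appear (bounds stored as -1/-2), so the accumulation loop below is skipped.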
+
+     for (index_t emb_id = threadIdx.x; emb_id < aligned_emb_dim; emb_id += blockDim.x) {
+       // Initialize grad_in
+       if (req == kAddTo) {
+         grad_in_row[threadIdx.x] = aligned_grad_in[my_row * aligned_emb_dim + emb_id];
+       } else {
+         grad_in_row[threadIdx.x] = 0.0;
+       }
+       // Add all rows from grad_out according to indices in data
+       for (index_t data_idx = lower_bound; data_idx < (lower_bound + nOccurrences); ++data_idx) {
+         *Lvalue = aligned_grad_out[original_index[data_idx] * aligned_emb_dim + emb_id];
+         for (index_t val_id = 0; val_id < n_val; val_id++) {
+           my_grad_in_row[val_id] += Dvalues[val_id];
+         }
+       }
+       // Save results
+       aligned_grad_in[my_row * aligned_emb_dim + emb_id] = grad_in_row[threadIdx.x];
+     }
+   }
+ }
+
+ template <typename gpu, typename IType, typename DType>
+ void EmbeddingGradKernelCaller(const OpContext& ctx,
+                                mshadow::Tensor<gpu, 2, DType> grad_in,
+                                const mshadow::Tensor<gpu, 1, IType>& index,
+                                const mshadow::Tensor<gpu, 2, DType> &grad_out,
+                                const std::vector<OpReqType>& req) {
+   using namespace mxnet_op;
+   using namespace mshadow::expr;
+
+   Stream<gpu> *s = ctx.get_stream<gpu>();
+   const index_t data_dim = index.shape_[0];
+   const index_t vocab_dim = grad_in.shape_[0];
+   const index_t embbedding_dim = grad_in.shape_[1];
+
+   // Calculate amount of temporary storage
+   size_t sort_workspace_size = mxnet::op::SortByKeyWorkspaceSize<int, int, gpu>
+                                (data_dim);
+   size_t workspace_size = 2 * data_dim * sizeof(int) +
+                           2 * vocab_dim * sizeof(int) + sort_workspace_size;
+
+   // Request temporary storage
+   Tensor<gpu, 1, char> workspace =
+     ctx.requested[embedding::kTempSpace].get_space_typed<gpu, 1, char>(
+       Shape1(workspace_size), s);
+
+   // Create tensors
+   size_t pos = 0;
+   Tensor<gpu, 1, int> sorted_data(reinterpret_cast<int*>(&workspace[pos]),
+                                   Shape1(data_dim), s);
+   pos += data_dim * sizeof(int);
+   // Reference to input data positions for each element of sorted_data
+   Tensor<gpu, 1, int> original_index(reinterpret_cast<int*>(&workspace[pos]),
+                                      Shape1(data_dim), s);
+   pos += data_dim * sizeof(int);
+   // lower and upper bound positions of each index within sorted_data
+   Tensor<gpu, 1, int> bounds_index(reinterpret_cast<int*>(&workspace[pos]),
+                                    Shape1(2 * vocab_dim), s);
+   pos += 2 * vocab_dim * sizeof(int);
+   Tensor<gpu, 1, char> Sort_temp_storage(&workspace[pos], Shape1(sort_workspace_size), s);
+
+   // Clip indices [0, vocab_dim-1]
+   Kernel<tcast_clip, gpu>::Launch(s, data_dim, sorted_data.dptr_, index.dptr_,
+                                   static_cast<int>(vocab_dim));
+
+   Kernel<range_fwd, gpu>::Launch(s, data_dim,
+                                  1, 0, 1, kWriteTo, original_index.dptr_);
+
+   // Sort indices array
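+   // ilog2(vocab_dim - 1) bounds the number of significant key bits, so SortByKey only needs to sort on those bits.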
+   int num_bits = ilog2((vocab_dim - 1));
+   mxnet::op::SortByKey(sorted_data, original_index, true, &Sort_temp_storage, 0, num_bits);
+
+   // Find lower & upper bounds of each possible index
+   const int threads_block_bounds = 128;
+   const int nblocks_bounds = (vocab_dim + threads_block_bounds - 1) / threads_block_bounds;
+   EmbeddingFindBounds<<<nblocks_bounds, threads_block_bounds, 0, Stream<gpu>::GetStream(s)>>>(
+     sorted_data.dptr_, bounds_index.dptr_, data_dim, vocab_dim);
+
+   // Compute Gradient
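+   // sizeof(LType) must evenly divide embbedding_dim * sizeof(DType) so each row can be read and written as whole LType words.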
+   int ltype = mxnet::common::cuda::get_load_type(embbedding_dim * sizeof(DType));
+   MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+     int nelems_per_thread = sizeof(LType) / sizeof(DType);
+     int threads_block_grad = 32;
+     int maxThreads = 1024;
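+     // Grow the block in warp-sized (32-thread) steps until one pass covers the whole vectorized row, capped at maxThreads.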
+     while (threads_block_grad < (embbedding_dim / nelems_per_thread) &&
+            (threads_block_grad < maxThreads))
+       threads_block_grad += 32;
+     size_t required_shared = threads_block_grad * sizeof(LType);
+     dim3 blocks(vocab_dim, 1);
+     EmbeddingGradKernel<LType><<<blocks, threads_block_grad, required_shared,
+                                  Stream<gpu>::GetStream(s)>>>(
+       grad_in.dptr_, original_index.dptr_,
+       bounds_index.dptr_, grad_out.dptr_,
+       embbedding_dim, vocab_dim,
+       req[embedding::kWeight]);
+   });
+ }
+
+ template <>
+ void EmbeddingOpBackward<gpu>(const nnvm::NodeAttrs& attrs,
+                               const OpContext& ctx,
+                               const std::vector<TBlob>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<TBlob>& outputs) {
+   using namespace mshadow;
+   using namespace mshadow::expr;
+   CHECK_EQ(inputs.size(), 2U);
+   CHECK_EQ(outputs.size(), 2U);
+   CHECK_EQ(req[embedding::kData], kNullOp)
+     << "Embedding layer doesn't support calculating the data gradient";
+   if (req[embedding::kWeight] == kNullOp) {
+     return;
+   }
+   CHECK_EQ(outputs[1].type_flag_, inputs[0].type_flag_);
+
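+   // inputs[0] holds the gradient w.r.t. the Embedding output, inputs[1] holds the indices; outputs[1] is the weight gradient.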
+   const mxnet::TShape& ishape = inputs[1].shape_;
+   const mxnet::TShape& oshape = inputs[0].shape_;
+
+   Stream<gpu> *s = ctx.get_stream<gpu>();
+   CHECK_NE(req[embedding::kWeight], kWriteInplace)
+     << "Backward of Embedding does not support writing in place.";
+   MSHADOW_TYPE_SWITCH(outputs[1].type_flag_, DType, {
+     MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {
+       Tensor<gpu, 1, IType> data = inputs[1].get_with_shape<gpu, 1, IType>(
+         Shape1(ishape.ProdShape(0, ishape.ndim())), s);
+       Tensor<gpu, 2, DType> grad_out = inputs[0].get_with_shape<gpu, 2, DType>(
+         Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
+       Tensor<gpu, 2, DType> grad_in = outputs[1].get<gpu, 2, DType>(s);
+
+       if (req[embedding::kWeight] == kWriteTo || req[embedding::kWeight] == kAddTo) {
+         EmbeddingGradKernelCaller(ctx, grad_in, data, grad_out, req);
+       } else {
+         LOG(FATAL) << "wrong req";
+       }
+     });
+   });
+ }
+
NNVM_REGISTER_OP(Embedding)
.set_attr<FCompute>("FCompute<gpu>", EmbeddingOpForward<gpu>)
.set_attr<FComputeEx>("FComputeEx<gpu>", SparseEmbeddingOpForwardEx<gpu>);