@@ -842,53 +842,73 @@ class RNNOp {
#endif // MXNET_USE_CUDNN_GE_7200
}
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+ #if !defined(__CUDACC__)
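+ // CPU-only forward path: compiled when this file is not built by NVCC.
+ // projection_size stays 0 unless the optional LSTM projection layer is configured.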
+ int projection_size = 0;
+ if (param_.projection_size.has_value()) {
+   projection_size = param_.projection_size.value();
+ }

- if (ctx_.dev_type == kCPU) {
-   int projection_size = 0;
+ // allocate temp space
+ const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+     param_.state_size, projection_size, direction, param_.mode);
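+ // The workspace is cached across calls and grow-only: it is reallocated
+ // only when a larger size is requested.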
+ if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
+   temp_cpu_space_size_ = work_cpu_space_size;
+   temp_cpu_space_ = NDArray(TShape({static_cast<dim_t>(temp_cpu_space_size_)}), ctx_,
+       false, in_data[rnn_enum::kData].type_flag_);
+   temp_init_space_ = true;
+ }
+ DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
+
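+ // Training (or any pass that needs gradients) uses the requested RNG for dropout
+ // and a reserve space for intermediate activations; inference needs neither.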
+ if (ctx.is_train || ctx.need_grad) {
+   mshadow::Random<cpu, unsigned> *prnd = ctx.requested[0].get_random<xpu, unsigned int>(s);
+   std::mt19937 &rnd_engine = prnd->GetRndEngine();
+
+   // allocate reserve space
  if (param_.projection_size.has_value()) {
-     projection_size = param_.projection_size.value();
+     LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
  }

-   // allocate temp space
-   const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-       param_.state_size, projection_size, direction, param_.mode);
-   if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
-     temp_cpu_space_size_ = work_cpu_space_size;
-     temp_cpu_space_ = NDArray(TShape({static_cast<dim_t>(temp_cpu_space_size_)}), ctx_,
+   const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+       param_.seq_length_, param_.batch_size_,
+       param_.state_size, param_.mode);
+   if (!init_space_ || reserve_cpu_space_size_ < r_size) {
+     reserve_cpu_space_size_ = r_size;
+     reserve_cpu_space_ = NDArray(TShape({static_cast<dim_t>(reserve_cpu_space_size_)}), ctx_,
        false, in_data[rnn_enum::kData].type_flag_);
-     temp_init_space_ = true;
+     init_space_ = true;
  }
-   DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
-
-   if (ctx.is_train || ctx.need_grad) {
-     mshadow::Random<cpu, unsigned> *prnd = ctx.requested[0].get_random<xpu, unsigned int>(s);
-     std::mt19937 &rnd_engine = prnd->GetRndEngine();
-
-     // allocate reserve space
-     if (param_.projection_size.has_value()) {
-       LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
-     }
-
-     const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-         param_.seq_length_, param_.batch_size_,
-         param_.state_size, param_.mode);
-     if (!init_space_ || reserve_cpu_space_size_ < r_size) {
-       reserve_cpu_space_size_ = r_size;
-       reserve_cpu_space_ = NDArray(TShape({static_cast<dim_t>(reserve_cpu_space_size_)}), ctx_,
-         false, in_data[rnn_enum::kData].type_flag_);
-       init_space_ = true;
-     }
-     DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);
+   DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);

-     RNNForwardTraining<DType>(work_cpu_space,
-         reserve_space_ptr,
+   RNNForwardTraining<DType>(work_cpu_space,
+       reserve_space_ptr,
+       param_.state_outputs,
+       param_.num_layers,
+       direction,
+       param_.seq_length_,
+       param_.batch_size_,
+       param_.input_size_,
+       param_.state_size,
+       x.dptr_,
+       hx.dptr_,
+       cx_ptr,
+       w.dptr_,
+       b_ptr,
+       y.dptr_,
+       hy_ptr,
+       cy_ptr,
+       param_.p,
+       param_.mode,
+       rnd_engine);
+ } else {
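+   // Unlike the CPU training path above, the inference kernel supports LSTM
+   // projection, so projection_size is forwarded to it.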
+   RNNForwardInference<DType>(work_cpu_space,
        param_.state_outputs,
        param_.num_layers,
        direction,
        param_.seq_length_,
        param_.batch_size_,
        param_.input_size_,
        param_.state_size,
+       projection_size,
        x.dptr_,
        hx.dptr_,
        cx_ptr,
@@ -897,30 +917,9 @@ class RNNOp {
        y.dptr_,
        hy_ptr,
        cy_ptr,
-       param_.p,
-       param_.mode,
-       rnd_engine);
-   } else {
-     RNNForwardInference<DType>(work_cpu_space,
-         param_.state_outputs,
-         param_.num_layers,
-         direction,
-         param_.seq_length_,
-         param_.batch_size_,
-         param_.input_size_,
-         param_.state_size,
-         projection_size,
-         x.dptr_,
-         hx.dptr_,
-         cx_ptr,
-         w.dptr_,
-         b_ptr,
-         y.dptr_,
-         hy_ptr,
-         cy_ptr,
-         param_.mode);
-   }
+       param_.mode);
}
+ #endif // !defined(__CUDACC__)
}

void Backward(const OpContext &ctx,