This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 0a921a4

Fix CPU-only RNNOp Forward

1 parent 567518b

File tree

1 file changed: +55 −56 lines

src/operator/rnn-inl.h (55 additions, 56 deletions)
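The change below moves the CPU forward path out of a runtime `if (ctx_.dev_type == kCPU)` branch and under a compile-time `#if !defined(__CUDACC__)` guard, and additionally passes `projection_size` through to `RNNForwardInference`. For orientation, a minimal standalone sketch of the guard pattern; `DeviceType`, `RunCpuForward`, `ForwardBefore`, and `ForwardAfter` are illustrative stand-ins, not names from the source:

// Sketch only: runtime dispatch vs. compile-time guard.
enum class DeviceType { kCPU, kGPU };

inline void RunCpuForward() { /* CPU-only work */ }

// Before: decided at run time, so nvcc still had to compile the CPU branch.
inline void ForwardBefore(DeviceType dev) {
  if (dev == DeviceType::kCPU) {
    RunCpuForward();
  }
}

// After: when nvcc compiles the file it defines __CUDACC__, so the CPU-only
// branch is dropped at preprocessing time and never reaches the CUDA compiler.
inline void ForwardAfter() {
#if !defined(__CUDACC__)
  RunCpuForward();
#endif
}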
@@ -842,53 +842,73 @@ class RNNOp {
 #endif  // MXNET_USE_CUDNN_GE_7200
     }
 #endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+#if !defined(__CUDACC__)
+    int projection_size = 0;
+    if (param_.projection_size.has_value()) {
+      projection_size = param_.projection_size.value();
+    }
 
-    if (ctx_.dev_type == kCPU) {
-      int projection_size = 0;
+    // allocate temp space
+    const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+        param_.state_size, projection_size, direction, param_.mode);
+    if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
+      temp_cpu_space_size_ = work_cpu_space_size;
+      temp_cpu_space_ = NDArray(TShape({static_cast<dim_t>(temp_cpu_space_size_)}), ctx_,
+          false, in_data[rnn_enum::kData].type_flag_);
+      temp_init_space_ = true;
+    }
+    DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
+
+    if (ctx.is_train || ctx.need_grad) {
+      mshadow::Random<cpu, unsigned> *prnd = ctx.requested[0].get_random<xpu, unsigned int>(s);
+      std::mt19937 &rnd_engine = prnd->GetRndEngine();
+
+      // allocate reserve space
       if (param_.projection_size.has_value()) {
-        projection_size = param_.projection_size.value();
+        LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
       }
 
-      // allocate temp space
-      const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-          param_.state_size, projection_size, direction, param_.mode);
-      if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
-        temp_cpu_space_size_ = work_cpu_space_size;
-        temp_cpu_space_ = NDArray(TShape({static_cast<dim_t>(temp_cpu_space_size_)}), ctx_,
+      const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                                   param_.seq_length_, param_.batch_size_,
+                                                   param_.state_size, param_.mode);
+      if (!init_space_ || reserve_cpu_space_size_ < r_size) {
+        reserve_cpu_space_size_ = r_size;
+        reserve_cpu_space_ = NDArray(TShape({static_cast<dim_t>(reserve_cpu_space_size_)}), ctx_,
             false, in_data[rnn_enum::kData].type_flag_);
-        temp_init_space_ = true;
+        init_space_ = true;
      }
-      DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
-
-      if (ctx.is_train || ctx.need_grad) {
-        mshadow::Random<cpu, unsigned> *prnd = ctx.requested[0].get_random<xpu, unsigned int>(s);
-        std::mt19937 &rnd_engine = prnd->GetRndEngine();
-
-        // allocate reserve space
-        if (param_.projection_size.has_value()) {
-          LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
-        }
-
-        const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                                     param_.seq_length_, param_.batch_size_,
-                                                     param_.state_size, param_.mode);
-        if (!init_space_ || reserve_cpu_space_size_ < r_size) {
-          reserve_cpu_space_size_ = r_size;
-          reserve_cpu_space_ = NDArray(TShape({static_cast<dim_t>(reserve_cpu_space_size_)}), ctx_,
-              false, in_data[rnn_enum::kData].type_flag_);
-          init_space_ = true;
-        }
-        DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);
+      DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);
 
-        RNNForwardTraining<DType>(work_cpu_space,
-                                  reserve_space_ptr,
+      RNNForwardTraining<DType>(work_cpu_space,
+                                reserve_space_ptr,
+                                param_.state_outputs,
+                                param_.num_layers,
+                                direction,
+                                param_.seq_length_,
+                                param_.batch_size_,
+                                param_.input_size_,
+                                param_.state_size,
+                                x.dptr_,
+                                hx.dptr_,
+                                cx_ptr,
+                                w.dptr_,
+                                b_ptr,
+                                y.dptr_,
+                                hy_ptr,
+                                cy_ptr,
+                                param_.p,
+                                param_.mode,
+                                rnd_engine);
+    } else {
+      RNNForwardInference<DType>(work_cpu_space,
                                  param_.state_outputs,
                                  param_.num_layers,
                                  direction,
                                  param_.seq_length_,
                                  param_.batch_size_,
                                  param_.input_size_,
                                  param_.state_size,
+                                 projection_size,
                                  x.dptr_,
                                  hx.dptr_,
                                  cx_ptr,
@@ -897,30 +917,9 @@ class RNNOp {
                                  y.dptr_,
                                  hy_ptr,
                                  cy_ptr,
-                                 param_.p,
-                                 param_.mode,
-                                 rnd_engine);
-      } else {
-        RNNForwardInference<DType>(work_cpu_space,
-                                   param_.state_outputs,
-                                   param_.num_layers,
-                                   direction,
-                                   param_.seq_length_,
-                                   param_.batch_size_,
-                                   param_.input_size_,
-                                   param_.state_size,
-                                   projection_size,
-                                   x.dptr_,
-                                   hx.dptr_,
-                                   cx_ptr,
-                                   w.dptr_,
-                                   b_ptr,
-                                   y.dptr_,
-                                   hy_ptr,
-                                   cy_ptr,
-                                   param_.mode);
-      }
+                                 param_.mode);
     }
+#endif  // !defined(__CUDACC__)
   }
 
   void Backward(const OpContext &ctx,
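Two details of the new code are easy to miss. First, the temp workspace and the reserve space are both grow-only caches: the backing NDArray is reallocated only when the required size exceeds what is already cached. A self-contained sketch of that pattern; `WorkspaceCache` and its members are illustrative stand-ins, with `std::vector` standing in for the NDArray backing store:

#include <cstddef>
#include <vector>

// Grow-only workspace cache, mirroring the temp_cpu_space_ /
// reserve_cpu_space_ logic in the diff above.
class WorkspaceCache {
 public:
  // Returns a buffer of at least `needed` elements, reallocating only when
  // the request outgrows the cached capacity, as in
  // `if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size)`.
  float* Get(std::size_t needed) {
    if (!initialized_ || capacity_ < needed) {
      capacity_ = needed;
      buf_.assign(capacity_, 0.0f);
      initialized_ = true;
    }
    return buf_.data();
  }

 private:
  std::vector<float> buf_;
  std::size_t capacity_ = 0;
  bool initialized_ = false;
};

Second, `RNNForwardInference` now receives `projection_size` explicitly, while the training branch still rejects LSTM projection with `LOG(FATAL)`, so on CPU projection remains inference-only.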

0 commit comments