
Commit 6590194

xziyagyshi authored and committed
MKL-DNN RNN checks NDArray version (apache#16071)
* MKL-DNN RNN checks NDArray version
* Add UT
* Use default_context()
1 parent e09ccbb commit 6590194
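
In short, the cached MKL-DNN weight memory is now tied to the version() of the parameter NDArray: the op records the version when it builds the cache and drops the cache once the version no longer matches (for example after copy_params_from overwrites the weights). A minimal sketch of the assumed version-counter model, using a hypothetical standalone type rather than MXNet's NDArray:

#include <cstddef>

// Hypothetical stand-in for a versioned array (not MXNet's NDArray): every
// in-place write bumps the counter, so an unchanged counter means the
// already-converted MKL-DNN weights cached from this array are still valid.
class VersionedBuffer {
 public:
  size_t version() const { return version_; }
  void Write() { ++version_; }  // e.g. copy_params_from overwriting parameters
 private:
  size_t version_ = 0;
};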

File tree

3 files changed: +68 -64 lines changed

  src/operator/rnn-inl.h
  src/operator/rnn.cc
  tests/python/unittest/test_operator.py

src/operator/rnn-inl.h

Lines changed: 21 additions & 58 deletions
@@ -409,10 +409,11 @@ class RNNOp {
   std::vector<mkldnn::memory> bias_memory;
   std::vector<mkldnn::memory> y_memory;
   std::vector<mkldnn::memory> hcy_memory;
+  size_t weights_version;
   bool has_cache;
   bool init_mem_;
   size_t reserve_mem_size_;
-  Storage::Handle mem_space_;
+  NDArray mem_space_;
 #endif
   explicit RNNOp(RNNParam param, Context ctx) {
     this->param_ = param;
@@ -522,12 +523,6 @@ class RNNOp {
   }
 
   ~RNNOp() {
-#if MXNET_USE_MKLDNN == 1
-    if (init_mem_) {
-      Storage::Get()->Free(mem_space_);
-      init_mem_ = false;
-    }
-#endif  // MXNET_USE_MKLDNN
 #if MXNET_USE_CUDNN == 1
     CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_));
     CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_));
@@ -560,17 +555,6 @@ class RNNOp {
     CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dy_data_desc_));
 #endif  // MXNET_USE_CUDNN_GE_7200
 #endif  // MXNET_USE_CUDNN
-
-    if (ctx_.dev_type == kCPU) {
-      if (init_space_) {
-        Storage::Get()->Free(reserve_cpu_space_);
-        init_space_ = false;
-      }
-      if (temp_init_space_) {
-        Storage::Get()->Free(temp_cpu_space_);
-        temp_init_space_ = false;
-      }
-    }
   }
 
   void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
@@ -855,37 +839,30 @@ class RNNOp {
 #endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
 
     if (ctx_.dev_type == kCPU) {
+      // allocate temp space
+      const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                                                             param_.state_size, direction, param_.mode);
+      if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
+        temp_cpu_space_size_ = work_cpu_space_size;
+        temp_cpu_space_ = NDArray(TShape({static_cast<dim_t>(temp_cpu_space_size_)}), ctx_,
+                                  false, in_data[rnn_enum::kData].type_flag_);
+        temp_init_space_ = true;
+      }
+      DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
+
       if (ctx.is_train) {
-        // allocate temp space
-        const size_t work_cpu_space_size =
-            GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                                param_.state_size, direction, param_.mode);
-        if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) {
-          Storage::Get()->Free(temp_cpu_space_);
-          temp_init_space_ = false;
-        }
-        if (!temp_init_space_) {
-          temp_cpu_space_ = Storage::Get()->Alloc
-              (work_cpu_space_size * sizeof(DType), Context::CPU());
-          temp_cpu_space_size_ = work_cpu_space_size;
-          temp_init_space_ = true;
-        }
-        DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.dptr);
+        // allocate reserve space
 
         const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
                                                      param_.seq_length_, param_.batch_size_,
                                                      param_.state_size, param_.mode);
-        if (init_space_ && reserve_cpu_space_size_ < r_size) {
-          Storage::Get()->Free(reserve_cpu_space_);
-          init_space_ = false;
-        }
-        if (!init_space_) {
-          reserve_cpu_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
+        if (!init_space_ || reserve_cpu_space_size_ < r_size) {
           reserve_cpu_space_size_ = r_size;
+          reserve_cpu_space_ = NDArray(TShape({static_cast<dim_t>(reserve_cpu_space_size_)}), ctx_,
+                                       false, in_data[rnn_enum::kData].type_flag_);
           init_space_ = true;
         }
-
-        DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+        DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);
 
         RNNForwardTraining<DType>(work_cpu_space,
                                   reserve_space_ptr,
@@ -945,20 +922,6 @@ class RNNOp {
 #endif  // MXNET_USE_MKLDNN == 1
       // Before integrating MKLDNN GRU fp32 inference
       // using below code for keep func being OK
-      const size_t work_cpu_space_size =
-          GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                              param_.state_size, direction, param_.mode);
-      if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) {
-        Storage::Get()->Free(temp_cpu_space_);
-        temp_init_space_ = false;
-      }
-      if (!temp_init_space_) {
-        temp_cpu_space_ = Storage::Get()->Alloc
-            (work_cpu_space_size * sizeof(DType), Context::CPU());
-        temp_cpu_space_size_ = work_cpu_space_size;
-        temp_init_space_ = true;
-      }
-      DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.dptr);
      RNNForwardInference<DType>(work_cpu_space,
                                 param_.state_outputs,
                                 param_.num_layers,
@@ -1171,7 +1134,7 @@ class RNNOp {
       if (!temp_init_space_ || temp_cpu_space_size_ != work_cpu_space_size) {
         LOG(FATAL) << "Check temp init error";
       }
-      DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.dptr);
+      DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.data().dptr_);
       size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
                                              param_.seq_length_, param_.batch_size_,
                                              param_.state_size, param_.mode);
@@ -1180,7 +1143,7 @@ class RNNOp {
         LOG(FATAL) << "Check forward init error";
       }
 
-      DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+      DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.data().dptr_);
       RNNBackward<DType>(work_cpu_space,
                          reserve_space_ptr,
                          param_.num_layers,
@@ -1551,7 +1514,7 @@ class RNNOp {
 #endif  // MXNET_USE_CUDNN
   bool init_space_, temp_init_space_;
   size_t reserve_cpu_space_size_, temp_cpu_space_size_;
-  Storage::Handle reserve_cpu_space_, temp_cpu_space_;
+  NDArray reserve_cpu_space_, temp_cpu_space_;
 };  // class RNNOp
 
 static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
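
A recurring pattern in the hunks above: the free-then-reallocate handling of temp_cpu_space_ and reserve_cpu_space_ (plus the explicit Storage::Get()->Free calls in the destructor) collapses into a single grow-if-needed branch backed by an NDArray whose storage is released automatically. A simplified sketch of that grow-only idea, with std::vector standing in for the NDArray-backed buffer purely for illustration:

#include <cstddef>
#include <vector>

// Grow-only workspace: reallocate only when the requested size exceeds the
// recorded capacity; RAII ownership means no manual Free is needed.
template <typename DType>
class Workspace {
 public:
  DType* Require(size_t size) {
    if (!initialized_ || capacity_ < size) {
      capacity_ = size;
      buffer_.assign(size, DType());  // resize the backing store to the larger size
      initialized_ = true;
    }
    return buffer_.data();
  }

 private:
  bool initialized_ = false;
  size_t capacity_ = 0;
  std::vector<DType> buffer_;
};

As in the diff, the buffer only grows over the life of the op, so repeated Forward calls with the same shapes never reallocate.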

src/operator/rnn.cc

Lines changed: 8 additions & 6 deletions
@@ -270,22 +270,24 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
 
     const size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, param.mode);
     if (op.init_mem_ && op.reserve_mem_size_ < r_size) {
-      Storage::Get()->Free(op.mem_space_);
       op.init_mem_ = false;
     }
+    const size_t weights_version = inputs[rnn_enum::kParams].version();
     if (!op.init_mem_) {
-      op.mem_space_ = Storage::Get()->Alloc(
-          r_size * sizeof(DType),
-          Context::CPU());
+      op.mem_space_ = NDArray(TShape({static_cast<dim_t>(r_size)}), op.ctx_, false, dtype);
       op.reserve_mem_size_ = r_size;
       op.init_mem_ = true;
       op.has_cache = false;
+      // Assign weights_version
+      op.weights_version = weights_version;
     }
-    if (op.has_cache && op.x_memory.size() == 0) {
+    // Check if NDArray was changed.
+    if (op.weights_version != weights_version) {
       op.has_cache = false;
+      op.weights_version = weights_version;
     }
 
-    DType* workptr = static_cast<DType*>(op.mem_space_.dptr);
+    DType* workptr = static_cast<DType*>(op.mem_space_.data().dptr_);
     mkldnn::memory::dims src_layer_tz_0 = {T, N, I};
     mkldnn::memory::dims src_layer_tz = {T, N, D * H};
     mkldnn::memory::dims dst_layer_tz = {T, N, D * H};
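
Taken together, the hunk above leaves three pieces of per-call state on the op: the workspace NDArray with its recorded size, the has_cache flag guarding the converted MKL-DNN weight memory, and the recorded weights_version. A condensed restatement of that control flow, using hypothetical standalone types instead of the real RNNOp/NDArray API:

#include <cstddef>

// Hypothetical stand-ins for the op state touched in RNNStatefulComputeCPU.
struct RnnCacheState {
  bool init_mem = false;          // op.init_mem_
  size_t reserve_mem_size = 0;    // op.reserve_mem_size_
  bool has_cache = false;         // op.has_cache
  size_t weights_version = 0;     // op.weights_version
};

// r_size: workspace elements needed on this call;
// params_version: inputs[rnn_enum::kParams].version() in the real code.
inline void UpdateCacheState(RnnCacheState* s, size_t r_size, size_t params_version) {
  if (s->init_mem && s->reserve_mem_size < r_size) {
    s->init_mem = false;                 // workspace too small: re-create it below
  }
  if (!s->init_mem) {
    // the real code allocates op.mem_space_ as an NDArray of r_size elements here
    s->reserve_mem_size = r_size;
    s->init_mem = true;
    s->has_cache = false;                // weight memory must be rebuilt
    s->weights_version = params_version;
  }
  if (s->weights_version != params_version) {
    s->has_cache = false;                // parameters were overwritten: drop the cache
    s->weights_version = params_version;
  }
}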

tests/python/unittest/test_operator.py

Lines changed: 39 additions & 0 deletions
@@ -75,6 +75,45 @@ def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, atol=1e
     assert(mod2.get_input_grads()[0] == None)
 
 
+@with_seed()
+@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
+def test_rnn_with_new_param():
+    rnn_modes = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm']
+    ngates_ = [1, 1, 3, 4]
+    num_layers, input_size, seq_len, batch_size, state_size = 3, 128, 5, 64, 8
+    for bidirectional in [False, True]:
+        directions = 2 if bidirectional else 1
+        for mode, ngates in zip(rnn_modes, ngates_):
+            first_layer_size = (input_size * state_size + state_size * state_size + state_size * 2) * ngates
+            rest_layer_size = (state_size * directions * state_size + state_size * state_size + state_size * 2) \
+                * ngates * (num_layers - 1)
+            param_size = (first_layer_size + rest_layer_size) * directions
+            sym = mx.sym.RNN(mode=mode, num_layers=num_layers, bidirectional=bidirectional,
+                             state_outputs=False, state_size=state_size, name='rnn')
+
+            bind_dict = {
+                'rnn_data': mx.ndarray.random.uniform(low=-1, high=1, shape=(seq_len, batch_size, input_size)),
+                'rnn_parameters': mx.ndarray.random.uniform(low=-1, high=1, shape=(param_size)),
+                'rnn_state': mx.ndarray.zeros(shape=(num_layers * directions, batch_size, state_size))
+            }
+            if mode == 'lstm':
+                bind_dict['rnn_state_cell'] = mx.ndarray.zeros(
+                    shape=(num_layers * directions, batch_size, state_size))
+
+            ex = sym.bind(default_context(), bind_dict)
+            ex.forward(is_train=True)
+            ex01 = ex.output_dict['rnn_output'].asnumpy()
+            ex.forward(is_train=False)
+            ex02 = ex.output_dict['rnn_output'].asnumpy()
+            assert_allclose(ex01, ex02, rtol=1e-2, atol=1e-4)
+            bind_dict['rnn_parameters'] = mx.ndarray.random.uniform(low=-1, high=1, shape=(param_size))
+            ex.copy_params_from(bind_dict)
+            ex.forward(is_train=True)
+            ex03 = ex.output_dict['rnn_output'].asnumpy()
+            ex.forward(is_train=False)
+            ex04 = ex.output_dict['rnn_output'].asnumpy()
+            assert_allclose(ex03, ex04, rtol=1e-2, atol=1e-4)
+
 
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')

0 commit comments
