This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 582489c

ZhennanQin authored and pengzhao-intel committed
Fix Cached_op with static_shape=true (#15298)
* Fix
* run ci
1 parent ba30644 commit 582489c

File tree

2 files changed: +24 −30 lines

src/imperative/cached_op.cc
src/nnvm/legacy_op_util.cc


src/imperative/cached_op.cc

Lines changed: 5 additions & 2 deletions

@@ -81,6 +81,7 @@ struct CachedOp::CachedOpState {
 
   std::vector<NDArray> buff;
   std::vector<NDArray*> arrays;
+  std::vector<NDArray*> arrays_with_in_out;
   std::vector<OpReqType> array_reqs;
 
   std::vector<OpStatePtr> op_states;
@@ -762,7 +763,8 @@ OpStatePtr CachedOp::StaticForward(
   // We are going to add input and output arrays to the array list.
   // The input and output arrays should only be valid for this run,
   // so we shouldn't modify the state's array list.
-  auto arrays = state.arrays;
+  state.arrays_with_in_out = state.arrays;
+  auto& arrays = state.arrays_with_in_out;
   if (config_.static_shape) {
     for (auto i : config_.param_indices) {
       auto nid = idx.input_nodes()[i];
@@ -1063,7 +1065,8 @@ void CachedOp::StaticBackward(
   // We are going to add input and output arrays to the array list.
   // The input and output arrays should only be valid for this run,
   // so we shouldn't modify the state's array list.
-  auto arrays = state.arrays;
+  state.arrays_with_in_out = state.arrays;
+  auto& arrays = state.arrays_with_in_out;
   for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) {
     auto eid = state.info.bwd_input_eid[i];
     if (eid == kEidNotExist) {
src/nnvm/legacy_op_util.cc

Lines changed: 19 additions & 28 deletions

@@ -79,7 +79,6 @@ class OperatorState {
  public:
   OperatorState(Operator *opr, const OperatorProperty *prop) {
     opr_ = opr;
-    fwd_init_ = bwd_init_ = false;
 
     in_data_fwd_.resize(prop->ListArguments().size());
     in_data_bwd_.resize(prop->ListArguments().size());
@@ -110,47 +109,39 @@ class OperatorState {
                const std::vector<TBlob>& inputs,
                const std::vector<OpReqType>& req,
                const std::vector<TBlob>& outputs) {
-    if (!fwd_init_) {
-      CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
-      CHECK_EQ(outputs.size(), out_data_.size());
-      // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
-      // referred by arg_data_ptr_ will be overriden
-      for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
-      for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
-      for (size_t i = 0; i < aux_data_.size(); ++i) {
-        aux_data_[i] = inputs[i + in_data_fwd_.size()];
-      }
-      for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
-      fwd_init_ = true;
+    CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
+    CHECK_EQ(outputs.size(), out_data_.size());
+    // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
+    // referred by arg_data_ptr_ will be overriden
+    for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
+    for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
+    for (size_t i = 0; i < aux_data_.size(); ++i) {
+      aux_data_[i] = inputs[i + in_data_fwd_.size()];
     }
+    for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
     opr_->Forward(ctx, in_data_fwd_, req, out_data_, aux_data_);
   }
 
   void Backward(const OpContext &ctx,
                 const std::vector<TBlob>& inputs,
                 const std::vector<OpReqType>& req,
                 const std::vector<TBlob>& outputs) {
-    if (!bwd_init_) {
-      CHECK(fwd_init_);
-      CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
-      // override tblobs pointed by arg_data_ptr_ since they might not contain
-      // initialized data during forward pass.
-      for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
-        *arg_data_ptr_[i] = inputs[i];
-      }
-      for (size_t i = 0; i < aux_data_.size(); ++i) {
-        aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
-      }
-      CHECK_EQ(outputs.size(), in_grad_.size());
-      for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
-      bwd_init_ = true;
+    CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
+    // override tblobs pointed by arg_data_ptr_ since they might not contain
+    // initialized data during forward pass.
+    for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
+      *arg_data_ptr_[i] = inputs[i];
+    }
+    for (size_t i = 0; i < aux_data_.size(); ++i) {
+      aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
     }
+    CHECK_EQ(outputs.size(), in_grad_.size());
+    for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
     opr_->Backward(ctx, out_grad_, in_data_bwd_, out_data_, req, in_grad_, aux_data_);
   }
 
  private:
   Operator *opr_;
-  bool fwd_init_, bwd_init_;
   // input data blobs for forward and backward
   // in_data_fwd_ and in_data_bwd_ will hold different tblobs when StorageFallbackOpExecutor
   // performs storage fallback on a non-default input NDArray. The one in in_data_fwd_ is