@@ -79,7 +79,6 @@ class OperatorState {
79
79
public:
80
80
OperatorState (Operator *opr, const OperatorProperty *prop) {
81
81
opr_ = opr;
82
- fwd_init_ = bwd_init_ = false ;
83
82
84
83
in_data_fwd_.resize (prop->ListArguments ().size ());
85
84
in_data_bwd_.resize (prop->ListArguments ().size ());
@@ -110,47 +109,39 @@ class OperatorState {
110
109
const std::vector<TBlob>& inputs,
111
110
const std::vector<OpReqType>& req,
112
111
const std::vector<TBlob>& outputs) {
113
- if (!fwd_init_) {
114
- CHECK_EQ (inputs.size (), in_data_fwd_.size () + aux_data_.size ());
115
- CHECK_EQ (outputs.size (), out_data_.size ());
116
- // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
117
- // referred by arg_data_ptr_ will be overriden
118
- for (size_t i = 0 ; i < in_data_fwd_.size (); ++i) in_data_fwd_[i] = inputs[i];
119
- for (size_t i = 0 ; i < in_data_fwd_.size (); ++i) in_data_bwd_[i] = inputs[i];
120
- for (size_t i = 0 ; i < aux_data_.size (); ++i) {
121
- aux_data_[i] = inputs[i + in_data_fwd_.size ()];
122
- }
123
- for (size_t i = 0 ; i < out_data_.size (); ++i) out_data_[i] = outputs[i];
124
- fwd_init_ = true ;
112
+ CHECK_EQ (inputs.size (), in_data_fwd_.size () + aux_data_.size ());
113
+ CHECK_EQ (outputs.size (), out_data_.size ());
114
+ // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
115
+ // referred by arg_data_ptr_ will be overriden
116
+ for (size_t i = 0 ; i < in_data_fwd_.size (); ++i) in_data_fwd_[i] = inputs[i];
117
+ for (size_t i = 0 ; i < in_data_fwd_.size (); ++i) in_data_bwd_[i] = inputs[i];
118
+ for (size_t i = 0 ; i < aux_data_.size (); ++i) {
119
+ aux_data_[i] = inputs[i + in_data_fwd_.size ()];
125
120
}
121
+ for (size_t i = 0 ; i < out_data_.size (); ++i) out_data_[i] = outputs[i];
126
122
opr_->Forward (ctx, in_data_fwd_, req, out_data_, aux_data_);
127
123
}
128
124
129
125
void Backward (const OpContext &ctx,
130
126
const std::vector<TBlob>& inputs,
131
127
const std::vector<OpReqType>& req,
132
128
const std::vector<TBlob>& outputs) {
133
- if (!bwd_init_) {
134
- CHECK (fwd_init_);
135
- CHECK_EQ (arg_data_ptr_.size () + aux_data_.size (), inputs.size ());
136
- // override tblobs pointed by arg_data_ptr_ since they might not contain
137
- // initialized data during forward pass.
138
- for (size_t i = 0 ; i < arg_data_ptr_.size (); ++i) {
139
- *arg_data_ptr_[i] = inputs[i];
140
- }
141
- for (size_t i = 0 ; i < aux_data_.size (); ++i) {
142
- aux_data_[i] = inputs[inputs.size () - aux_data_.size () + i];
143
- }
144
- CHECK_EQ (outputs.size (), in_grad_.size ());
145
- for (size_t i = 0 ; i < outputs.size (); ++i) in_grad_[i] = outputs[i];
146
- bwd_init_ = true ;
129
+ CHECK_EQ (arg_data_ptr_.size () + aux_data_.size (), inputs.size ());
130
+ // override tblobs pointed by arg_data_ptr_ since they might not contain
131
+ // initialized data during forward pass.
132
+ for (size_t i = 0 ; i < arg_data_ptr_.size (); ++i) {
133
+ *arg_data_ptr_[i] = inputs[i];
134
+ }
135
+ for (size_t i = 0 ; i < aux_data_.size (); ++i) {
136
+ aux_data_[i] = inputs[inputs.size () - aux_data_.size () + i];
147
137
}
138
+ CHECK_EQ (outputs.size (), in_grad_.size ());
139
+ for (size_t i = 0 ; i < outputs.size (); ++i) in_grad_[i] = outputs[i];
148
140
opr_->Backward (ctx, out_grad_, in_data_bwd_, out_data_, req, in_grad_, aux_data_);
149
141
}
150
142
151
143
private:
152
144
Operator *opr_;
153
- bool fwd_init_, bwd_init_;
154
145
// input data blobs for forward and backward
155
146
// in_data_fwd_ and in_data_bwd_ will hold different tblobs when StorageFallbackOpExecutor
156
147
// performs storage fallback on a non-default input NDArray. The one in in_data_fwd_ is
0 commit comments