@@ -102,6 +102,23 @@ StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
102
102
return result;
103
103
}
104
104
105
+ // Suspends/Enables training by setting the training_ flag. Serialize and
106
+ // DeSerialize only operate on the run-time data if state is false.
107
+ void LSTM::SetEnableTraining (TrainingState state) {
108
+ if (state == TS_RE_ENABLE) {
109
+ if (training_ == TS_DISABLED) {
110
+ for (int w = 0 ; w < WT_COUNT; ++w) {
111
+ if (w == GFS && !Is2D ()) continue ;
112
+ gate_weights_[w].InitBackward (false );
113
+ }
114
+ }
115
+ training_ = TS_ENABLED;
116
+ } else {
117
+ training_ = state;
118
+ }
119
+ if (softmax_ != NULL ) softmax_->SetEnableTraining (state);
120
+ }
121
+
105
122
// Sets up the network for training. Initializes weights using weights of
106
123
// scale `range` picked according to the random number generator `randomizer`.
107
124
int LSTM::InitWeights (float range, TRand* randomizer) {
@@ -148,7 +165,7 @@ bool LSTM::Serialize(TFile* fp) const {
148
165
if (fp->FWrite (&na_, sizeof (na_), 1 ) != 1 ) return false ;
149
166
for (int w = 0 ; w < WT_COUNT; ++w) {
150
167
if (w == GFS && !Is2D ()) continue ;
151
- if (!gate_weights_[w].Serialize (training_ , fp)) return false ;
168
+ if (!gate_weights_[w].Serialize (IsTraining () , fp)) return false ;
152
169
}
153
170
if (softmax_ != NULL && !softmax_->Serialize (fp)) return false ;
154
171
return true ;
@@ -169,7 +186,7 @@ bool LSTM::DeSerialize(bool swap, TFile* fp) {
169
186
is_2d_ = false ;
170
187
for (int w = 0 ; w < WT_COUNT; ++w) {
171
188
if (w == GFS && !Is2D ()) continue ;
172
- if (!gate_weights_[w].DeSerialize (training_ , swap, fp)) return false ;
189
+ if (!gate_weights_[w].DeSerialize (IsTraining () , swap, fp)) return false ;
173
190
if (w == CI) {
174
191
ns_ = gate_weights_[CI].NumOutputs ();
175
192
is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
@@ -322,7 +339,7 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
322
339
MultiplyAccumulate (ns_, temp_lines[CI], temp_lines[GI], curr_state);
323
340
// Clip curr_state to a sane range.
324
341
ClipVector<double >(ns_, -kStateClip , kStateClip , curr_state);
325
- if (training_ ) {
342
+ if (IsTraining () ) {
326
343
// Save the gate node values.
327
344
node_values_[CI].WriteTimeStep (t, temp_lines[CI]);
328
345
node_values_[GI].WriteTimeStep (t, temp_lines[GI]);
@@ -331,7 +348,7 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
331
348
if (Is2D ()) node_values_[GFS].WriteTimeStep (t, temp_lines[GFS]);
332
349
}
333
350
FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
334
- if (training_ ) state_.WriteTimeStep (t, curr_state);
351
+ if (IsTraining () ) state_.WriteTimeStep (t, curr_state);
335
352
if (softmax_ != NULL ) {
336
353
if (input.int_mode ()) {
337
354
int_output->WriteTimeStep (0 , curr_output);
@@ -697,7 +714,7 @@ void LSTM::PrintDW() {
697
714
void LSTM::ResizeForward (const NetworkIO& input) {
698
715
source_.Resize (input, na_);
699
716
which_fg_.ResizeNoInit (input.Width (), ns_);
700
- if (training_ ) {
717
+ if (IsTraining () ) {
701
718
state_.ResizeFloat (input, ns_);
702
719
for (int w = 0 ; w < WT_COUNT; ++w) {
703
720
if (w == GFS && !Is2D ()) continue ;
0 commit comments