Skip to content

Commit 4fa463c

Browse files
committed
Corrected SetEnableTraining for recovery from a recognize-only model.
1 parent 006a56c commit 4fa463c

File tree

6 files changed

+29
-16
lines changed

6 files changed

+29
-16
lines changed

lstm/fullyconnected.cpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,17 @@ StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const {
5656
return result;
5757
}
5858

59-
// Suspends/Enables training by setting the training_ flag. Serialize and
60-
// DeSerialize only operate on the run-time data if state is false.
59+
// Suspends/Enables training by setting the training_ flag.
6160
void FullyConnected::SetEnableTraining(TrainingState state) {
6261
if (state == TS_RE_ENABLE) {
63-
if (training_ == TS_DISABLED) weights_.InitBackward(false);
64-
training_ = TS_ENABLED;
62+
// Enable only from temp disabled.
63+
if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
64+
} else if (state == TS_TEMP_DISABLE) {
65+
// Temp disable only from enabled.
66+
if (training_ == TS_ENABLED) training_ = state;
6567
} else {
68+
if (state == TS_ENABLED && training_ == TS_DISABLED)
69+
weights_.InitBackward();
6670
training_ = state;
6771
}
6872
}

lstm/lstm.cpp

+8-4
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,18 @@ StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
107107
// DeSerialize only operate on the run-time data if state is false.
108108
void LSTM::SetEnableTraining(TrainingState state) {
109109
if (state == TS_RE_ENABLE) {
110-
if (training_ == TS_DISABLED) {
110+
// Enable only from temp disabled.
111+
if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
112+
} else if (state == TS_TEMP_DISABLE) {
113+
// Temp disable only from enabled.
114+
if (training_ == TS_ENABLED) training_ = state;
115+
} else {
116+
if (state == TS_ENABLED && training_ == TS_DISABLED) {
111117
for (int w = 0; w < WT_COUNT; ++w) {
112118
if (w == GFS && !Is2D()) continue;
113-
gate_weights_[w].InitBackward(false);
119+
gate_weights_[w].InitBackward();
114120
}
115121
}
116-
training_ = TS_ENABLED;
117-
} else {
118122
training_ = state;
119123
}
120124
if (softmax_ != NULL) softmax_->SetEnableTraining(state);

lstm/network.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,11 @@ Network::~Network() {
111111
// recognizer can be converted back to a trainer.
112112
void Network::SetEnableTraining(TrainingState state) {
113113
if (state == TS_RE_ENABLE) {
114-
training_ = TS_ENABLED;
114+
// Enable only from temp disabled.
115+
if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
116+
} else if (state == TS_TEMP_DISABLE) {
117+
// Temp disable only from enabled.
118+
if (training_ == TS_ENABLED) training_ = state;
115119
} else {
116120
training_ = state;
117121
}

lstm/network.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,10 @@ enum TrainingState {
9393
// Valid states of training_.
9494
TS_DISABLED, // Disabled permanently.
9595
TS_ENABLED, // Enabled for backprop and to write a training dump.
96+
// Re-enable from ANY disabled state.
9697
TS_TEMP_DISABLE, // Temporarily disabled to write a recognition dump.
9798
// Valid only for SetEnableTraining.
98-
TS_RE_ENABLE, // Re-Enable whatever the current state.
99+
TS_RE_ENABLE, // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED.
99100
};
100101

101102
// Base class for network types. Not quite an abstract base class, but almost.

lstm/weightmatrix.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ int WeightMatrix::InitWeightsFloat(int no, int ni, bool ada_grad,
4747
}
4848
}
4949
}
50-
InitBackward(ada_grad);
50+
use_ada_grad_ = ada_grad;
51+
InitBackward();
5152
return ni * no;
5253
}
5354

@@ -83,10 +84,9 @@ void WeightMatrix::ConvertToInt() {
8384

8485
// Allocates any needed memory for running Backward, and zeroes the deltas,
8586
// thus eliminating any existing momentum.
86-
void WeightMatrix::InitBackward(bool ada_grad) {
87+
void WeightMatrix::InitBackward() {
8788
int no = int_mode_ ? wi_.dim1() : wf_.dim1();
8889
int ni = int_mode_ ? wi_.dim2() : wf_.dim2();
89-
use_ada_grad_ = ada_grad;
9090
dw_.Resize(no, ni, 0.0);
9191
updates_.Resize(no, ni, 0.0);
9292
wf_t_.Transpose(wf_);
@@ -134,7 +134,7 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
134134
} else {
135135
if (!wf_.DeSerialize(fp)) return false;
136136
if (training) {
137-
InitBackward(use_ada_grad_);
137+
InitBackward();
138138
if (!updates_.DeSerialize(fp)) return false;
139139
if (use_ada_grad_ && !dw_sq_sum_.DeSerialize(fp)) return false;
140140
}
@@ -157,7 +157,7 @@ bool WeightMatrix::DeSerializeOld(bool training, TFile* fp) {
157157
FloatToDouble(float_array, &wf_);
158158
}
159159
if (training) {
160-
InitBackward(use_ada_grad_);
160+
InitBackward();
161161
if (!float_array.DeSerialize(fp)) return false;
162162
FloatToDouble(float_array, &updates_);
163163
// Errs was only used in int training, which is now dead.

lstm/weightmatrix.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class WeightMatrix {
9292

9393
// Allocates any needed memory for running Backward, and zeroes the deltas,
9494
// thus eliminating any existing momentum.
95-
void InitBackward(bool ada_grad);
95+
void InitBackward();
9696

9797
// Writes to the given file. Returns false in case of error.
9898
bool Serialize(bool training, TFile* fp) const;

0 commit comments

Comments (0)