Skip to content

Commit 34d1e73

Browse files
committed
LSTMTrainer: Catch empty vectors
The new test in LSTMTrainer::UpdateErrorGraph fixes an assertion (see issues #644, #792). The new test in LSTMTrainer::ReadTrainingDump was added to improve the robustness of the code. Signed-off-by: Stefan Weil <[email protected]>
1 parent 1e5522d commit 34d1e73

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

lstm/lstmtrainer.cpp

+7-2
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,10 @@ bool LSTMTrainer::SaveTrainingDump(SerializeAmount serialize_amount,
918918
// Reads previously saved trainer from memory.
919919
bool LSTMTrainer::ReadTrainingDump(const GenericVector<char>& data,
920920
LSTMTrainer* trainer) {
921+
if (data.size() == 0) {
922+
tprintf("Warning: data size is zero in LSTMTrainer::ReadTrainingDump\n");
923+
return false;
924+
}
921925
return trainer->ReadSizedTrainingDump(&data[0], data.size());
922926
}
923927

@@ -1298,8 +1302,9 @@ STRING LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate,
12981302
if (error_rate < best_error_rate_) {
12991303
// This is a new (global) minimum.
13001304
if (tester != NULL) {
1301-
result = tester->Run(worst_iteration_, worst_error_rates_,
1302-
worst_model_data_, CurrentTrainingStage());
1305+
if (worst_model_data_.size() != 0)
1306+
result = tester->Run(worst_iteration_, worst_error_rates_,
1307+
worst_model_data_, CurrentTrainingStage());
13031308
worst_model_data_.truncate(0);
13041309
best_model_data_ = model_data;
13051310
}

0 commit comments

Comments
 (0)