|
| 1 | +#ifndef TESSERACT_UNITTEST_LSTM_TEST_H_ |
| 2 | +#define TESSERACT_UNITTEST_LSTM_TEST_H_ |
| 3 | + |
| 4 | +#include <memory> |
| 5 | +#include <string> |
| 6 | +#include <utility> |
| 7 | + |
| 8 | +#include "base/logging.h" |
| 9 | +#include "base/stringprintf.h" |
| 10 | +#include "file/base/file.h" |
| 11 | +#include "file/base/helpers.h" |
| 12 | +#include "file/base/path.h" |
| 13 | +#include "testing/base/public/googletest.h" |
| 14 | +#include "testing/base/public/gunit.h" |
| 15 | +#include "absl/strings/str_cat.h" |
| 16 | +#include "tesseract/ccutil/unicharset.h" |
| 17 | +#include "tesseract/lstm/functions.h" |
| 18 | +#include "tesseract/lstm/lstmtrainer.h" |
| 19 | +#include "tesseract/training/lang_model_helpers.h" |
| 20 | + |
| 21 | +namespace tesseract { |
| 22 | + |
| 23 | +#if DEBUG_DETAIL == 0 |
| 24 | +// Number of iterations to run all the trainers. |
| 25 | +const int kTrainerIterations = 600; |
| 26 | +// Number of iterations between accuracy checks. |
| 27 | +const int kBatchIterations = 100; |
| 28 | +#else |
| 29 | +// Number of iterations to run all the trainers. |
| 30 | +const int kTrainerIterations = 2; |
| 31 | +// Number of iterations between accuracy checks. |
| 32 | +const int kBatchIterations = 1; |
| 33 | +#endif |
| 34 | + |
| 35 | +// The fixture for testing LSTMTrainer. |
| 36 | +class LSTMTrainerTest : public testing::Test { |
| 37 | + protected: |
| 38 | + LSTMTrainerTest() {} |
| 39 | + string TestDataNameToPath(const string& name) { |
| 40 | + return file::JoinPath(FLAGS_test_srcdir, |
| 41 | + "tesseract/testdata/" + name); |
| 42 | + } |
| 43 | + |
| 44 | + void SetupTrainerEng(const string& network_spec, const string& model_name, |
| 45 | + bool recode, bool adam) { |
| 46 | + SetupTrainer(network_spec, model_name, "eng.unicharset", |
| 47 | + "lstm_training.arial.lstmf", recode, adam, 5e-4, false); |
| 48 | + } |
| 49 | + void SetupTrainer(const string& network_spec, const string& model_name, |
| 50 | + const string& unicharset_file, const string& lstmf_file, |
| 51 | + bool recode, bool adam, double learning_rate, |
| 52 | + bool layer_specific) { |
| 53 | + constexpr char kLang[] = "eng"; // Exact value doesn't matter. |
| 54 | + string unicharset_name = TestDataNameToPath(unicharset_file); |
| 55 | + UNICHARSET unicharset; |
| 56 | + ASSERT_TRUE(unicharset.load_from_file(unicharset_name.c_str(), false)); |
| 57 | + string script_dir = file::JoinPath( |
| 58 | + FLAGS_test_srcdir, "tesseract/training/langdata"); |
| 59 | + GenericVector<STRING> words; |
| 60 | + EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, "", FLAGS_test_tmpdir, |
| 61 | + kLang, !recode, words, words, words, false, |
| 62 | + nullptr, nullptr)); |
| 63 | + string model_path = file::JoinPath(FLAGS_test_tmpdir, model_name); |
| 64 | + string checkpoint_path = model_path + "_checkpoint"; |
| 65 | + trainer_.reset(new LSTMTrainer(nullptr, nullptr, nullptr, nullptr, |
| 66 | + model_path.c_str(), checkpoint_path.c_str(), |
| 67 | + 0, 0)); |
| 68 | + trainer_->InitCharSet(file::JoinPath(FLAGS_test_tmpdir, kLang, |
| 69 | + absl::StrCat(kLang, ".traineddata"))); |
| 70 | + int net_mode = adam ? NF_ADAM : 0; |
| 71 | + // Adam needs a higher learning rate, due to not multiplying the effective |
| 72 | + // rate by 1/(1-momentum). |
| 73 | + if (adam) learning_rate *= 20.0; |
| 74 | + if (layer_specific) net_mode |= NF_LAYER_SPECIFIC_LR; |
| 75 | + EXPECT_TRUE(trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1, |
| 76 | + learning_rate, 0.9, 0.999)); |
| 77 | + GenericVector<STRING> filenames; |
| 78 | + filenames.push_back(STRING(TestDataNameToPath(lstmf_file).c_str())); |
| 79 | + EXPECT_TRUE(trainer_->LoadAllTrainingData(filenames, CS_SEQUENTIAL, false)); |
| 80 | + LOG(INFO) << "Setup network:" << model_name; |
| 81 | + } |
| 82 | + // Trains for a given number of iterations and returns the char error rate. |
| 83 | + double TrainIterations(int max_iterations) { |
| 84 | + int iteration = trainer_->training_iteration(); |
| 85 | + int iteration_limit = iteration + max_iterations; |
| 86 | + double best_error = 100.0; |
| 87 | + do { |
| 88 | + STRING log_str; |
| 89 | + int target_iteration = iteration + kBatchIterations; |
| 90 | + // Train a few. |
| 91 | + double mean_error = 0.0; |
| 92 | + while (iteration < target_iteration && iteration < iteration_limit) { |
| 93 | + trainer_->TrainOnLine(trainer_.get(), false); |
| 94 | + iteration = trainer_->training_iteration(); |
| 95 | + mean_error += trainer_->LastSingleError(ET_CHAR_ERROR); |
| 96 | + } |
| 97 | + trainer_->MaintainCheckpoints(NULL, &log_str); |
| 98 | + iteration = trainer_->training_iteration(); |
| 99 | + mean_error *= 100.0 / kBatchIterations; |
| 100 | + LOG(INFO) << log_str.string(); |
| 101 | + LOG(INFO) << "Batch error = " << mean_error; |
| 102 | + if (mean_error < best_error) best_error = mean_error; |
| 103 | + } while (iteration < iteration_limit); |
| 104 | + LOG(INFO) << "Trainer error rate = " << best_error; |
| 105 | + return best_error; |
| 106 | + } |
| 107 | + // Tests for a given number of iterations and returns the char error rate. |
| 108 | + double TestIterations(int max_iterations) { |
| 109 | + CHECK_GT(max_iterations, 0); |
| 110 | + int iteration = trainer_->sample_iteration(); |
| 111 | + double mean_error = 0.0; |
| 112 | + int error_count = 0; |
| 113 | + while (error_count < max_iterations) { |
| 114 | + const ImageData& trainingdata = |
| 115 | + *trainer_->mutable_training_data()->GetPageBySerial(iteration); |
| 116 | + NetworkIO fwd_outputs, targets; |
| 117 | + if (trainer_->PrepareForBackward(&trainingdata, &fwd_outputs, &targets) != |
| 118 | + UNENCODABLE) { |
| 119 | + mean_error += trainer_->NewSingleError(ET_CHAR_ERROR); |
| 120 | + ++error_count; |
| 121 | + } |
| 122 | + trainer_->SetIteration(++iteration); |
| 123 | + } |
| 124 | + mean_error *= 100.0 / max_iterations; |
| 125 | + LOG(INFO) << "Tester error rate = " << mean_error; |
| 126 | + return mean_error; |
| 127 | + } |
| 128 | + // Tests that the current trainer_ can be converted to int mode and still gets |
| 129 | + // within 1% of the error rate. Returns the increase in error from float to |
| 130 | + // int. |
| 131 | + double TestIntMode(int test_iterations) { |
| 132 | + GenericVector<char> trainer_data; |
| 133 | + EXPECT_TRUE(trainer_->SaveTrainingDump(NO_BEST_TRAINER, trainer_.get(), |
| 134 | + &trainer_data)); |
| 135 | + // Get the error on the next few iterations in float mode. |
| 136 | + double float_err = TestIterations(test_iterations); |
| 137 | + // Restore the dump, convert to int and test error on that. |
| 138 | + EXPECT_TRUE(trainer_->ReadTrainingDump(trainer_data, trainer_.get())); |
| 139 | + trainer_->ConvertToInt(); |
| 140 | + double int_err = TestIterations(test_iterations); |
| 141 | + EXPECT_LT(int_err, float_err + 1.0); |
| 142 | + return int_err - float_err; |
| 143 | + } |
| 144 | + // Sets up a trainer with the given language and given recode+ctc condition. |
| 145 | + // It then verifies that the given str encodes and decodes back to the same |
| 146 | + // string. |
| 147 | + void TestEncodeDecode(const string& lang, const string& str, bool recode) { |
| 148 | + string unicharset_name = lang + ".unicharset"; |
| 149 | + SetupTrainer("[1,1,0,32 Lbx100 O1c1]", "bidi-lstm", unicharset_name, |
| 150 | + "arialuni.kor.lstmf", recode, true, 5e-4, true); |
| 151 | + GenericVector<int> labels; |
| 152 | + EXPECT_TRUE(trainer_->EncodeString(str.c_str(), &labels)); |
| 153 | + STRING decoded = trainer_->DecodeLabels(labels); |
| 154 | + string decoded_str(&decoded[0], decoded.length()); |
| 155 | + EXPECT_EQ(str, decoded_str); |
| 156 | + } |
| 157 | + // Calls TestEncodeDeode with both recode on and off. |
| 158 | + void TestEncodeDecodeBoth(const string& lang, const string& str) { |
| 159 | + TestEncodeDecode(lang, str, false); |
| 160 | + TestEncodeDecode(lang, str, true); |
| 161 | + } |
| 162 | + |
| 163 | + std::unique_ptr<LSTMTrainer> trainer_; |
| 164 | +}; |
| 165 | + |
| 166 | +} // namespace tesseract. |
| 167 | + |
| 168 | +#endif // THIRD_PARTY_TESSERACT_UNITTEST_LSTM_TEST_H_ |
0 commit comments