Skip to content

Commit d04133f

Browse files
committed
unittest: Add more files from Google
They were provided by Jeff Breidenbach <[email protected]>. Signed-off-by: Stefan Weil <[email protected]>
1 parent 9c2d1aa commit d04133f

File tree

2 files changed

+239
-0
lines changed

2 files changed

+239
-0
lines changed

unittest/lstm_test.h

+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
#ifndef TESSERACT_UNITTEST_LSTM_TEST_H_
2+
#define TESSERACT_UNITTEST_LSTM_TEST_H_
3+
4+
#include <memory>
5+
#include <string>
6+
#include <utility>
7+
8+
#include "base/logging.h"
9+
#include "base/stringprintf.h"
10+
#include "file/base/file.h"
11+
#include "file/base/helpers.h"
12+
#include "file/base/path.h"
13+
#include "testing/base/public/googletest.h"
14+
#include "testing/base/public/gunit.h"
15+
#include "absl/strings/str_cat.h"
16+
#include "tesseract/ccutil/unicharset.h"
17+
#include "tesseract/lstm/functions.h"
18+
#include "tesseract/lstm/lstmtrainer.h"
19+
#include "tesseract/training/lang_model_helpers.h"
20+
21+
namespace tesseract {
22+
23+
#if DEBUG_DETAIL == 0
24+
// Number of iterations to run all the trainers.
25+
const int kTrainerIterations = 600;
26+
// Number of iterations between accuracy checks.
27+
const int kBatchIterations = 100;
28+
#else
29+
// Number of iterations to run all the trainers.
30+
const int kTrainerIterations = 2;
31+
// Number of iterations between accuracy checks.
32+
const int kBatchIterations = 1;
33+
#endif
34+
35+
// The fixture for testing LSTMTrainer.
36+
class LSTMTrainerTest : public testing::Test {
37+
protected:
38+
LSTMTrainerTest() {}
39+
string TestDataNameToPath(const string& name) {
40+
return file::JoinPath(FLAGS_test_srcdir,
41+
"tesseract/testdata/" + name);
42+
}
43+
44+
void SetupTrainerEng(const string& network_spec, const string& model_name,
45+
bool recode, bool adam) {
46+
SetupTrainer(network_spec, model_name, "eng.unicharset",
47+
"lstm_training.arial.lstmf", recode, adam, 5e-4, false);
48+
}
49+
void SetupTrainer(const string& network_spec, const string& model_name,
50+
const string& unicharset_file, const string& lstmf_file,
51+
bool recode, bool adam, double learning_rate,
52+
bool layer_specific) {
53+
constexpr char kLang[] = "eng"; // Exact value doesn't matter.
54+
string unicharset_name = TestDataNameToPath(unicharset_file);
55+
UNICHARSET unicharset;
56+
ASSERT_TRUE(unicharset.load_from_file(unicharset_name.c_str(), false));
57+
string script_dir = file::JoinPath(
58+
FLAGS_test_srcdir, "tesseract/training/langdata");
59+
GenericVector<STRING> words;
60+
EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, "", FLAGS_test_tmpdir,
61+
kLang, !recode, words, words, words, false,
62+
nullptr, nullptr));
63+
string model_path = file::JoinPath(FLAGS_test_tmpdir, model_name);
64+
string checkpoint_path = model_path + "_checkpoint";
65+
trainer_.reset(new LSTMTrainer(nullptr, nullptr, nullptr, nullptr,
66+
model_path.c_str(), checkpoint_path.c_str(),
67+
0, 0));
68+
trainer_->InitCharSet(file::JoinPath(FLAGS_test_tmpdir, kLang,
69+
absl::StrCat(kLang, ".traineddata")));
70+
int net_mode = adam ? NF_ADAM : 0;
71+
// Adam needs a higher learning rate, due to not multiplying the effective
72+
// rate by 1/(1-momentum).
73+
if (adam) learning_rate *= 20.0;
74+
if (layer_specific) net_mode |= NF_LAYER_SPECIFIC_LR;
75+
EXPECT_TRUE(trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1,
76+
learning_rate, 0.9, 0.999));
77+
GenericVector<STRING> filenames;
78+
filenames.push_back(STRING(TestDataNameToPath(lstmf_file).c_str()));
79+
EXPECT_TRUE(trainer_->LoadAllTrainingData(filenames, CS_SEQUENTIAL, false));
80+
LOG(INFO) << "Setup network:" << model_name;
81+
}
82+
// Trains for a given number of iterations and returns the char error rate.
83+
double TrainIterations(int max_iterations) {
84+
int iteration = trainer_->training_iteration();
85+
int iteration_limit = iteration + max_iterations;
86+
double best_error = 100.0;
87+
do {
88+
STRING log_str;
89+
int target_iteration = iteration + kBatchIterations;
90+
// Train a few.
91+
double mean_error = 0.0;
92+
while (iteration < target_iteration && iteration < iteration_limit) {
93+
trainer_->TrainOnLine(trainer_.get(), false);
94+
iteration = trainer_->training_iteration();
95+
mean_error += trainer_->LastSingleError(ET_CHAR_ERROR);
96+
}
97+
trainer_->MaintainCheckpoints(NULL, &log_str);
98+
iteration = trainer_->training_iteration();
99+
mean_error *= 100.0 / kBatchIterations;
100+
LOG(INFO) << log_str.string();
101+
LOG(INFO) << "Batch error = " << mean_error;
102+
if (mean_error < best_error) best_error = mean_error;
103+
} while (iteration < iteration_limit);
104+
LOG(INFO) << "Trainer error rate = " << best_error;
105+
return best_error;
106+
}
107+
// Tests for a given number of iterations and returns the char error rate.
108+
double TestIterations(int max_iterations) {
109+
CHECK_GT(max_iterations, 0);
110+
int iteration = trainer_->sample_iteration();
111+
double mean_error = 0.0;
112+
int error_count = 0;
113+
while (error_count < max_iterations) {
114+
const ImageData& trainingdata =
115+
*trainer_->mutable_training_data()->GetPageBySerial(iteration);
116+
NetworkIO fwd_outputs, targets;
117+
if (trainer_->PrepareForBackward(&trainingdata, &fwd_outputs, &targets) !=
118+
UNENCODABLE) {
119+
mean_error += trainer_->NewSingleError(ET_CHAR_ERROR);
120+
++error_count;
121+
}
122+
trainer_->SetIteration(++iteration);
123+
}
124+
mean_error *= 100.0 / max_iterations;
125+
LOG(INFO) << "Tester error rate = " << mean_error;
126+
return mean_error;
127+
}
128+
// Tests that the current trainer_ can be converted to int mode and still gets
129+
// within 1% of the error rate. Returns the increase in error from float to
130+
// int.
131+
double TestIntMode(int test_iterations) {
132+
GenericVector<char> trainer_data;
133+
EXPECT_TRUE(trainer_->SaveTrainingDump(NO_BEST_TRAINER, trainer_.get(),
134+
&trainer_data));
135+
// Get the error on the next few iterations in float mode.
136+
double float_err = TestIterations(test_iterations);
137+
// Restore the dump, convert to int and test error on that.
138+
EXPECT_TRUE(trainer_->ReadTrainingDump(trainer_data, trainer_.get()));
139+
trainer_->ConvertToInt();
140+
double int_err = TestIterations(test_iterations);
141+
EXPECT_LT(int_err, float_err + 1.0);
142+
return int_err - float_err;
143+
}
144+
// Sets up a trainer with the given language and given recode+ctc condition.
145+
// It then verifies that the given str encodes and decodes back to the same
146+
// string.
147+
void TestEncodeDecode(const string& lang, const string& str, bool recode) {
148+
string unicharset_name = lang + ".unicharset";
149+
SetupTrainer("[1,1,0,32 Lbx100 O1c1]", "bidi-lstm", unicharset_name,
150+
"arialuni.kor.lstmf", recode, true, 5e-4, true);
151+
GenericVector<int> labels;
152+
EXPECT_TRUE(trainer_->EncodeString(str.c_str(), &labels));
153+
STRING decoded = trainer_->DecodeLabels(labels);
154+
string decoded_str(&decoded[0], decoded.length());
155+
EXPECT_EQ(str, decoded_str);
156+
}
157+
// Calls TestEncodeDeode with both recode on and off.
158+
void TestEncodeDecodeBoth(const string& lang, const string& str) {
159+
TestEncodeDecode(lang, str, false);
160+
TestEncodeDecode(lang, str, true);
161+
}
162+
163+
std::unique_ptr<LSTMTrainer> trainer_;
164+
};
165+
166+
} // namespace tesseract.
167+
168+
#endif // THIRD_PARTY_TESSERACT_UNITTEST_LSTM_TEST_H_

unittest/normstrngs_test.h

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#ifndef TESSERACT_UNITTEST_NORMSTRNGS_TEST_H_
2+
#define TESSERACT_UNITTEST_NORMSTRNGS_TEST_H_
3+
4+
#include <string>
5+
#include <vector>
6+
#include "base/stringprintf.h"
7+
#include "absl/strings/str_cat.h"
8+
#include "absl/strings/str_join.h"
9+
#include "tesseract/ccutil/unichar.h"
10+
11+
namespace tesseract {
12+
13+
inline string CodepointList(const std::vector<char32>& str32) {
14+
string result;
15+
int total_chars = str32.size();
16+
for (int i = 0; i < total_chars; ++i) {
17+
StringAppendF(&result, "[%x]", str32[i]);
18+
}
19+
return result;
20+
}
21+
22+
inline string PrintString32WithUnicodes(const string& str) {
23+
std::vector<char32> str32 = UNICHAR::UTF8ToUTF32(str.c_str());
24+
return absl::StrCat("\"", str, "\" ", CodepointList(str32));
25+
}
26+
27+
inline string PrintStringVectorWithUnicodes(const std::vector<string>& glyphs) {
28+
string result;
29+
for (const auto& s : glyphs) {
30+
absl::StrAppend(&result, "Glyph:", PrintString32WithUnicodes(s), "\n");
31+
}
32+
return result;
33+
}
34+
35+
inline void ExpectGraphemeModeResults(const string& str, UnicodeNormMode u_mode,
36+
int unicode_count, int glyph_count,
37+
int grapheme_count,
38+
const string& target_str) {
39+
std::vector<string> glyphs;
40+
EXPECT_TRUE(NormalizeCleanAndSegmentUTF8(
41+
u_mode, OCRNorm::kNone, GraphemeNormMode::kIndividualUnicodes, true,
42+
str.c_str(), &glyphs));
43+
EXPECT_EQ(glyphs.size(), unicode_count)
44+
<< PrintStringVectorWithUnicodes(glyphs);
45+
EXPECT_EQ(target_str, absl::StrJoin(glyphs.begin(), glyphs.end(), ""));
46+
EXPECT_TRUE(NormalizeCleanAndSegmentUTF8(u_mode, OCRNorm::kNone,
47+
GraphemeNormMode::kGlyphSplit, true,
48+
str.c_str(), &glyphs));
49+
EXPECT_EQ(glyphs.size(), glyph_count)
50+
<< PrintStringVectorWithUnicodes(glyphs);
51+
EXPECT_EQ(target_str, absl::StrJoin(glyphs.begin(), glyphs.end(), ""));
52+
EXPECT_TRUE(NormalizeCleanAndSegmentUTF8(u_mode, OCRNorm::kNone,
53+
GraphemeNormMode::kCombined, true,
54+
str.c_str(), &glyphs));
55+
EXPECT_EQ(glyphs.size(), grapheme_count)
56+
<< PrintStringVectorWithUnicodes(glyphs);
57+
EXPECT_EQ(target_str, absl::StrJoin(glyphs.begin(), glyphs.end(), ""));
58+
EXPECT_TRUE(NormalizeCleanAndSegmentUTF8(u_mode, OCRNorm::kNone,
59+
GraphemeNormMode::kSingleString,
60+
true, str.c_str(), &glyphs));
61+
EXPECT_EQ(glyphs.size(), 1) << PrintStringVectorWithUnicodes(glyphs);
62+
EXPECT_EQ(target_str, glyphs[0]);
63+
string result;
64+
EXPECT_TRUE(NormalizeUTF8String(
65+
u_mode, OCRNorm::kNone, GraphemeNorm::kNormalize, str.c_str(), &result));
66+
EXPECT_EQ(target_str, result);
67+
}
68+
69+
} // namespace tesseract
70+
71+
#endif // TESSERACT_UNITTEST_NORMSTRNGS_TEST_H_

0 commit comments

Comments
 (0)