Skip to content

Commit ce76d1c

Browse files
committed
Fixes to the training process to allow incremental training from a recognition model
1 parent 9d90567 commit ce76d1c

31 files changed

+650
-122
lines changed

ccmain/linerec.cpp

+8-2
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,7 @@ void Tesseract::TrainLineRecognizer(const STRING& input_imagename,
6464
return;
6565
}
6666
TrainFromBoxes(boxes, texts, block_list, &images);
67+
images.Shuffle();
6768
if (!images.SaveDocument(lstmf_name.string(), NULL)) {
6869
tprintf("Failed to write training data to %s!\n", lstmf_name.string());
6970
}
@@ -79,7 +80,10 @@ void Tesseract::TrainFromBoxes(const GenericVector<TBOX>& boxes,
7980
int box_count = boxes.size();
8081
// Process all the text lines in this page, as defined by the boxes.
8182
int end_box = 0;
82-
for (int start_box = 0; start_box < box_count; start_box = end_box) {
83+
// Don't let \t, which marks newlines in the box file, get into the line
84+
// content, as that makes the line unusable in training.
85+
while (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
86+
for (int start_box = end_box; start_box < box_count; start_box = end_box) {
8387
// Find the textline of boxes starting at start and their bounding box.
8488
TBOX line_box = boxes[start_box];
8589
STRING line_str = texts[start_box];
@@ -115,7 +119,9 @@ void Tesseract::TrainFromBoxes(const GenericVector<TBOX>& boxes,
115119
}
116120
if (imagedata != NULL)
117121
training_data->AddPageToDocument(imagedata);
118-
if (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
122+
// Don't let \t, which marks newlines in the box file, get into the line
123+
// content, as that makes the line unusable in training.
124+
while (end_box < texts.size() && texts[end_box] == "\t") ++end_box;
119125
}
120126
}
121127

ccstruct/boxread.cpp

+2
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,8 @@ bool ReadAllBoxes(int target_page, bool skip_blanks, const STRING& filename,
5555
GenericVector<char> box_data;
5656
if (!tesseract::LoadDataFromFile(BoxFileName(filename), &box_data))
5757
return false;
58+
// Convert the array of bytes to a string, so it can be used by the parser.
59+
box_data.push_back('\0');
5860
return ReadMemBoxes(target_page, skip_blanks, &box_data[0], boxes, texts,
5961
box_texts, pages);
6062
}

ccstruct/imagedata.cpp

+21-6
Original file line number | Diff line number | Diff line change
@@ -24,18 +24,18 @@
2424

2525
#include "imagedata.h"
2626

27+
#if defined(__MINGW32__)
28+
#include <unistd.h>
29+
#else
30+
#include <thread>
31+
#endif
32+
2733
#include "allheaders.h"
2834
#include "boxread.h"
2935
#include "callcpp.h"
3036
#include "helpers.h"
3137
#include "tprintf.h"
3238

33-
#if defined(__MINGW32__)
34-
# include <unistd.h>
35-
#else
36-
# include <thread>
37-
#endif
38-
3939
// Number of documents to read ahead while training. Doesn't need to be very
4040
// large.
4141
const int kMaxReadAhead = 8;
@@ -496,6 +496,21 @@ inT64 DocumentData::UnCache() {
496496
return memory_saved;
497497
}
498498

499+
// Shuffles all the pages in the document.
500+
void DocumentData::Shuffle() {
501+
TRand random;
502+
// Different documents get shuffled differently, but the same for the same
503+
// name.
504+
random.set_seed(document_name_.string());
505+
int num_pages = pages_.size();
506+
// Execute one random swap for each page in the document.
507+
for (int i = 0; i < num_pages; ++i) {
508+
int src = random.IntRand() % num_pages;
509+
int dest = random.IntRand() % num_pages;
510+
std::swap(pages_[src], pages_[dest]);
511+
}
512+
}
513+
499514
// Locks the pages_mutex_ and Loads as many pages can fit in max_memory_
500515
// starting at index pages_offset_.
501516
bool DocumentData::ReCachePages() {

ccstruct/imagedata.h

+2
Original file line number | Diff line number | Diff line change
@@ -266,6 +266,8 @@ class DocumentData {
266266
// Removes all pages from memory and frees the memory, but does not forget
267267
// the document metadata. Returns the memory saved.
268268
inT64 UnCache();
269+
// Shuffles all the pages in the document.
270+
void Shuffle();
269271

270272
private:
271273
// Sets the value of total_pages_ behind a mutex.

ccstruct/pageres.cpp

+6-7
Original file line number | Diff line number | Diff line change
@@ -529,13 +529,12 @@ void WERD_RES::FilterWordChoices(int debug_level) {
529529
if (choice->unichar_id(i) != best_choice->unichar_id(j) &&
530530
choice->certainty(i) - best_choice->certainty(j) < threshold) {
531531
if (debug_level >= 2) {
532-
STRING label;
533-
label.add_str_int("\nDiscarding bad choice #", index);
534-
choice->print(label.string());
535-
tprintf("i %d j %d Chunk %d Choice->Blob[i].Certainty %.4g"
536-
" BestChoice->ChunkCertainty[Chunk] %g Threshold %g\n",
537-
i, j, chunk, choice->certainty(i),
538-
best_choice->certainty(j), threshold);
532+
choice->print("WorstCertaintyDiffWorseThan");
533+
tprintf(
534+
"i %d j %d Choice->Blob[i].Certainty %.4g"
535+
" WorstOtherChoiceCertainty %g Threshold %g\n",
536+
i, j, choice->certainty(i), best_choice->certainty(j), threshold);
537+
tprintf("Discarding bad choice #%d\n", index);
539538
}
540539
delete it.extract();
541540
break;

ccutil/genericvector.h

+12-2
Original file line number | Diff line number | Diff line change
@@ -363,8 +363,7 @@ inline bool LoadDataFromFile(const STRING& filename,
363363
fseek(fp, 0, SEEK_END);
364364
size_t size = ftell(fp);
365365
fseek(fp, 0, SEEK_SET);
366-
// Pad with a 0, just in case we treat the result as a string.
367-
data->init_to_size(static_cast<int>(size) + 1, 0);
366+
data->init_to_size(static_cast<int>(size), 0);
368367
bool result = fread(&(*data)[0], 1, size, fp) == size;
369368
fclose(fp);
370369
return result;
@@ -380,6 +379,17 @@ inline bool SaveDataToFile(const GenericVector<char>& data,
380379
fclose(fp);
381380
return result;
382381
}
382+
// Reads a file as a vector of STRING.
383+
inline bool LoadFileLinesToStrings(const STRING& filename,
384+
GenericVector<STRING>* lines) {
385+
GenericVector<char> data;
386+
if (!LoadDataFromFile(filename.string(), &data)) {
387+
return false;
388+
}
389+
STRING lines_str(&data[0], data.size());
390+
lines_str.split('\n', lines);
391+
return true;
392+
}
383393

384394
template <typename T>
385395
bool cmp_eq(T const & t1, T const & t2) {

ccutil/helpers.h

+7
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,8 @@
2727

2828
#include <stdio.h>
2929
#include <string.h>
30+
#include <functional>
31+
#include <string>
3032

3133
#include "host.h"
3234

@@ -43,6 +45,11 @@ class TRand {
4345
void set_seed(uinT64 seed) {
4446
seed_ = seed;
4547
}
48+
// Sets the seed using a hash of a string.
49+
void set_seed(const std::string& str) {
50+
std::hash<std::string> hasher;
51+
set_seed(static_cast<uinT64>(hasher(str)));
52+
}
4653

4754
// Returns an integer in the range 0 to MAX_INT32.
4855
inT32 IntRand() {

lstm/fullyconnected.cpp

+17-6
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,17 @@ StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const {
5656
return result;
5757
}
5858

59+
// Suspends/Enables training by setting the training_ flag. Serialize and
60+
// DeSerialize only operate on the run-time data if state is false.
61+
void FullyConnected::SetEnableTraining(TrainingState state) {
62+
if (state == TS_RE_ENABLE) {
63+
if (training_ == TS_DISABLED) weights_.InitBackward(false);
64+
training_ = TS_ENABLED;
65+
} else {
66+
training_ = state;
67+
}
68+
}
69+
5970
// Sets up the network for training. Initializes weights using weights of
6071
// scale `range` picked according to the random number generator `randomizer`.
6172
int FullyConnected::InitWeights(float range, TRand* randomizer) {
@@ -78,14 +89,14 @@ void FullyConnected::DebugWeights() {
7889
// Writes to the given file. Returns false in case of error.
7990
bool FullyConnected::Serialize(TFile* fp) const {
8091
if (!Network::Serialize(fp)) return false;
81-
if (!weights_.Serialize(training_, fp)) return false;
92+
if (!weights_.Serialize(IsTraining(), fp)) return false;
8293
return true;
8394
}
8495

8596
// Reads from the given file. Returns false in case of error.
8697
// If swap is true, assumes a big/little-endian swap is needed.
8798
bool FullyConnected::DeSerialize(bool swap, TFile* fp) {
88-
if (!weights_.DeSerialize(training_, swap, fp)) return false;
99+
if (!weights_.DeSerialize(IsTraining(), swap, fp)) return false;
89100
return true;
90101
}
91102

@@ -129,14 +140,14 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
129140
}
130141
ForwardTimeStep(d_input, i_input, t, temp_line);
131142
output->WriteTimeStep(t, temp_line);
132-
if (training() && type_ != NT_SOFTMAX) {
143+
if (IsTraining() && type_ != NT_SOFTMAX) {
133144
acts_.CopyTimeStepFrom(t, *output, t);
134145
}
135146
}
136147
// Zero all the elements that are in the padding around images that allows
137148
// multiple different-sized images to exist in a single array.
138149
// acts_ is only used if this is not a softmax op.
139-
if (training() && type_ != NT_SOFTMAX) {
150+
if (IsTraining() && type_ != NT_SOFTMAX) {
140151
acts_.ZeroInvalidElements();
141152
}
142153
output->ZeroInvalidElements();
@@ -152,7 +163,7 @@ void FullyConnected::SetupForward(const NetworkIO& input,
152163
const TransposedArray* input_transpose) {
153164
// Softmax output is always float, so save the input type.
154165
int_mode_ = input.int_mode();
155-
if (training()) {
166+
if (IsTraining()) {
156167
acts_.Resize(input, no_);
157168
// Source_ is a transposed copy of input. It isn't needed if provided.
158169
external_source_ = input_transpose;
@@ -163,7 +174,7 @@ void FullyConnected::SetupForward(const NetworkIO& input,
163174
void FullyConnected::ForwardTimeStep(const double* d_input, const inT8* i_input,
164175
int t, double* output_line) {
165176
// input is copied to source_ line-by-line for cache coherency.
166-
if (training() && external_source_ == NULL && d_input != NULL)
177+
if (IsTraining() && external_source_ == NULL && d_input != NULL)
167178
source_t_.WriteStrided(t, d_input);
168179
if (d_input != NULL)
169180
weights_.MatrixDotVector(d_input, output_line);

lstm/fullyconnected.h

+4
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,10 @@ class FullyConnected : public Network {
6161
type_ = type;
6262
}
6363

64+
// Suspends/Enables training by setting the training_ flag. Serialize and
65+
// DeSerialize only operate on the run-time data if state is false.
66+
virtual void SetEnableTraining(TrainingState state);
67+
6468
// Sets up the network for training. Initializes weights using weights of
6569
// scale `range` picked according to the random number generator `randomizer`.
6670
virtual int InitWeights(float range, TRand* randomizer);

lstm/lstm.cpp

+22-5
Original file line number | Diff line number | Diff line change
@@ -102,6 +102,23 @@ StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
102102
return result;
103103
}
104104

105+
// Suspends/Enables training by setting the training_ flag. Serialize and
106+
// DeSerialize only operate on the run-time data if state is false.
107+
void LSTM::SetEnableTraining(TrainingState state) {
108+
if (state == TS_RE_ENABLE) {
109+
if (training_ == TS_DISABLED) {
110+
for (int w = 0; w < WT_COUNT; ++w) {
111+
if (w == GFS && !Is2D()) continue;
112+
gate_weights_[w].InitBackward(false);
113+
}
114+
}
115+
training_ = TS_ENABLED;
116+
} else {
117+
training_ = state;
118+
}
119+
if (softmax_ != NULL) softmax_->SetEnableTraining(state);
120+
}
121+
105122
// Sets up the network for training. Initializes weights using weights of
106123
// scale `range` picked according to the random number generator `randomizer`.
107124
int LSTM::InitWeights(float range, TRand* randomizer) {
@@ -148,7 +165,7 @@ bool LSTM::Serialize(TFile* fp) const {
148165
if (fp->FWrite(&na_, sizeof(na_), 1) != 1) return false;
149166
for (int w = 0; w < WT_COUNT; ++w) {
150167
if (w == GFS && !Is2D()) continue;
151-
if (!gate_weights_[w].Serialize(training_, fp)) return false;
168+
if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
152169
}
153170
if (softmax_ != NULL && !softmax_->Serialize(fp)) return false;
154171
return true;
@@ -169,7 +186,7 @@ bool LSTM::DeSerialize(bool swap, TFile* fp) {
169186
is_2d_ = false;
170187
for (int w = 0; w < WT_COUNT; ++w) {
171188
if (w == GFS && !Is2D()) continue;
172-
if (!gate_weights_[w].DeSerialize(training_, swap, fp)) return false;
189+
if (!gate_weights_[w].DeSerialize(IsTraining(), swap, fp)) return false;
173190
if (w == CI) {
174191
ns_ = gate_weights_[CI].NumOutputs();
175192
is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
@@ -322,7 +339,7 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
322339
MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
323340
// Clip curr_state to a sane range.
324341
ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
325-
if (training_) {
342+
if (IsTraining()) {
326343
// Save the gate node values.
327344
node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
328345
node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
@@ -331,7 +348,7 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
331348
if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
332349
}
333350
FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
334-
if (training_) state_.WriteTimeStep(t, curr_state);
351+
if (IsTraining()) state_.WriteTimeStep(t, curr_state);
335352
if (softmax_ != NULL) {
336353
if (input.int_mode()) {
337354
int_output->WriteTimeStep(0, curr_output);
@@ -697,7 +714,7 @@ void LSTM::PrintDW() {
697714
void LSTM::ResizeForward(const NetworkIO& input) {
698715
source_.Resize(input, na_);
699716
which_fg_.ResizeNoInit(input.Width(), ns_);
700-
if (training_) {
717+
if (IsTraining()) {
701718
state_.ResizeFloat(input, ns_);
702719
for (int w = 0; w < WT_COUNT; ++w) {
703720
if (w == GFS && !Is2D()) continue;

lstm/lstm.h

+4
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ class LSTM : public Network {
6969
return spec;
7070
}
7171

72+
// Suspends/Enables training by setting the training_ flag. Serialize and
73+
// DeSerialize only operate on the run-time data if state is false.
74+
virtual void SetEnableTraining(TrainingState state);
75+
7276
// Sets up the network for training. Initializes weights using weights of
7377
// scale `range` picked according to the random number generator `randomizer`.
7478
virtual int InitWeights(float range, TRand* randomizer);

lstm/lstmrecognizer.cpp

+2-2
Original file line number | Diff line number | Diff line change
@@ -253,7 +253,7 @@ bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
253253
float label_threshold, float* scale_factor,
254254
NetworkIO* inputs, NetworkIO* outputs) {
255255
// Maximum width of image to train on.
256-
const int kMaxImageWidth = 2048;
256+
const int kMaxImageWidth = 2560;
257257
// This ensures consistent recognition results.
258258
SetRandomSeed();
259259
int min_width = network_->XScaleFactor();
@@ -263,7 +263,7 @@ bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
263263
tprintf("Line cannot be recognized!!\n");
264264
return false;
265265
}
266-
if (network_->training() && pixGetWidth(pix) > kMaxImageWidth) {
266+
if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) {
267267
tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix),
268268
pixGetHeight(pix));
269269
pixDestroy(&pix);

0 commit comments

Comments (0)