22
22
#include " networkio.h"
23
23
#include " pageres.h"
24
24
#include " unicharcompress.h"
25
+ #include < set>
26
+ #include < vector>
25
27
26
28
#include < algorithm>
27
29
@@ -77,13 +79,18 @@ RecodeBeamSearch::RecodeBeamSearch(const UnicharCompress& recoder,
77
79
// Decodes the set of network outputs, storing the lattice internally.
78
80
void RecodeBeamSearch::Decode (const NetworkIO& output, double dict_ratio,
79
81
double cert_offset, double worst_dict_cert,
80
- const UNICHARSET* charset) {
82
+ const UNICHARSET* charset, bool glyph_confidence ) {
81
83
beam_size_ = 0 ;
82
84
int width = output.Width ();
85
+ if (glyph_confidence)
86
+ timesteps.clear ();
83
87
for (int t = 0 ; t < width; ++t) {
84
88
ComputeTopN (output.f (t), output.NumFeatures (), kBeamWidths [0 ]);
85
89
DecodeStep (output.f (t), t, dict_ratio, cert_offset, worst_dict_cert,
86
90
charset);
91
+ if (glyph_confidence) {
92
+ SaveMostCertainGlyphs (output.f (t), output.NumFeatures (), charset, t);
93
+ }
87
94
}
88
95
}
89
96
void RecodeBeamSearch::Decode (const GENERIC_2D_ARRAY<float >& output,
@@ -98,6 +105,35 @@ void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float>& output,
98
105
}
99
106
}
100
107
108
+ void RecodeBeamSearch::SaveMostCertainGlyphs (const float * outputs,
109
+ int num_outputs,
110
+ const UNICHARSET* charset,
111
+ int xCoord) {
112
+ std::vector<std::pair<const char *, float >> glyphs;
113
+ int pos = 0 ;
114
+ for (int i = 0 ; i < num_outputs; ++i) {
115
+ if (outputs[i] >= 0 .01f ) {
116
+ const char * charakter;
117
+ if (i + 2 >= num_outputs) {
118
+ charakter = " " ;
119
+ } else if (i > 0 ) {
120
+ charakter = charset->id_to_unichar_ext (i + 2 );
121
+ } else {
122
+ charakter = charset->id_to_unichar_ext (i);
123
+ }
124
+ pos = 0 ;
125
+ // order the possible glyphs within one timestep
126
+ // beginning with the most likely
127
+ while (glyphs.size () > pos && glyphs[pos].second > outputs[i]) {
128
+ pos++;
129
+ }
130
+ glyphs.insert (glyphs.begin () + pos,
131
+ std::pair<const char *, float >(charakter, outputs[i]));
132
+ }
133
+ }
134
+ timesteps.push_back (glyphs);
135
+ }
136
+
101
137
// Returns the best path as labels/scores/xcoords similar to simple CTC.
102
138
void RecodeBeamSearch::ExtractBestPathAsLabels (
103
139
GenericVector<int >* labels, GenericVector<int >* xcoords) const {
@@ -140,7 +176,8 @@ void RecodeBeamSearch::ExtractBestPathAsUnicharIds(
140
176
void RecodeBeamSearch::ExtractBestPathAsWords (const TBOX& line_box,
141
177
float scale_factor, bool debug,
142
178
const UNICHARSET* unicharset,
143
- PointerVector<WERD_RES>* words) {
179
+ PointerVector<WERD_RES>* words,
180
+ bool glyph_confidence) {
144
181
words->truncate (0 );
145
182
GenericVector<int > unichar_ids;
146
183
GenericVector<float > certs;
@@ -165,6 +202,7 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box,
165
202
}
166
203
// Convert labels to unichar-ids.
167
204
int word_end = 0 ;
205
+ int timestepEnd = 0 ;
168
206
float prev_space_cert = 0 .0f ;
169
207
for (int word_start = 0 ; word_start < num_ids; word_start = word_end) {
170
208
for (word_end = word_start + 1 ; word_end < num_ids; ++word_end) {
@@ -188,6 +226,12 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box,
188
226
WERD_RES* word_res = InitializeWord (
189
227
leading_space, line_box, word_start, word_end,
190
228
std::min (space_cert, prev_space_cert), unicharset, xcoords, scale_factor);
229
+ if (glyph_confidence) {
230
+ for (size_t i = timestepEnd; i < xcoords[word_end]; i++) {
231
+ word_res->timesteps .push_back (timesteps[i]);
232
+ }
233
+ timestepEnd = xcoords[word_end];
234
+ }
191
235
for (int i = word_start; i < word_end; ++i) {
192
236
BLOB_CHOICE_LIST* choices = new BLOB_CHOICE_LIST;
193
237
BLOB_CHOICE_IT bc_it (choices);
@@ -381,7 +425,7 @@ void RecodeBeamSearch::ComputeTopN(const float* outputs, int num_outputs,
381
425
void RecodeBeamSearch::DecodeStep (const float * outputs, int t,
382
426
double dict_ratio, double cert_offset,
383
427
double worst_dict_cert,
384
- const UNICHARSET* charset) {
428
+ const UNICHARSET* charset, bool debug ) {
385
429
if (t == beam_.size ()) beam_.push_back (new RecodeBeam);
386
430
RecodeBeam* step = beam_[t];
387
431
beam_size_ = t + 1 ;
@@ -396,7 +440,7 @@ void RecodeBeamSearch::DecodeStep(const float* outputs, int t,
396
440
}
397
441
} else {
398
442
RecodeBeam* prev = beam_[t - 1 ];
399
- if (charset != nullptr ) {
443
+ if (debug ) {
400
444
int beam_index = BeamIndex (true , NC_ANYTHING, 0 );
401
445
for (int i = prev->beams_ [beam_index].size () - 1 ; i >= 0 ; --i) {
402
446
GenericVector<const RecodeNode*> path;
0 commit comments