Skip to content

Commit 5b3e2fe

Browse files
committed
Integrated accumulated Symbol Choice in the Choice Iterator and made the api lstm_choice_mode independent
Signed-off-by: Noah Metzger <noah.metzger@bib.uni-mannheim.de>
1 parent bc2b919 commit 5b3e2fe

10 files changed

+174
-91
lines changed

src/api/hocrrenderer.cpp

+18-16
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,17 @@ char* TessBaseAPI::GetHOCRText(ETEXT_DESC* monitor, int page_number) {
213213
}
214214

215215
// Now, process the word...
216-
std::vector<std::vector<std::pair<const char*, float>>>* confidencemap =
216+
std::vector<std::vector<std::pair<const char*, float>>>* rawTimestepMap =
217+
nullptr;
218+
std::vector<std::vector<std::pair<const char*, float>>>* choiceMap =
217219
nullptr;
218220
std::vector<std::vector<std::vector<std::pair<const char*, float>>>>*
219221
symbolMap = nullptr;
220222
if (tesseract_->lstm_choice_mode) {
221-
confidencemap = res_it->GetBestLSTMSymbolChoices();
222-
symbolMap = res_it->GetBestSegmentedLSTMSymbolChoices();
223+
224+
choiceMap = res_it->GetBestLSTMSymbolChoices();
225+
symbolMap = res_it->GetSegmentedLSTMTimesteps();
226+
rawTimestepMap = res_it->GetRawLSTMTimesteps();
223227
}
224228
hocr_str << "\n <span class='ocrx_word'"
225229
<< " id='"
@@ -285,14 +289,14 @@ char* TessBaseAPI::GetHOCRText(ETEXT_DESC* monitor, int page_number) {
285289
if (italic) hocr_str << "</em>";
286290
if (bold) hocr_str << "</strong>";
287291
// If the lstm choice mode is required it is added here
288-
if (tesseract_->lstm_choice_mode == 1 && confidencemap != nullptr) {
289-
for (size_t i = 0; i < confidencemap->size(); i++) {
292+
if (tesseract_->lstm_choice_mode == 1 && rawTimestepMap != nullptr) {
293+
for (size_t i = 0; i < rawTimestepMap->size(); i++) {
290294
hocr_str << "\n <span class='ocrx_cinfo'"
291295
<< " id='"
292296
<< "timestep_" << page_id << "_" << wcnt << "_" << tcnt << "'"
293297
<< ">";
294298
std::vector<std::pair<const char*, float>> timestep =
295-
(*confidencemap)[i];
299+
(*rawTimestepMap)[i];
296300
for (std::pair<const char*, float> conf : timestep) {
297301
hocr_str << "<span class='ocr_glyph'"
298302
<< " id='"
@@ -304,17 +308,16 @@ char* TessBaseAPI::GetHOCRText(ETEXT_DESC* monitor, int page_number) {
304308
hocr_str << "</span>";
305309
tcnt++;
306310
}
307-
} else if (tesseract_->lstm_choice_mode == 2 && confidencemap != nullptr) {
308-
for (size_t i = 0; i < confidencemap->size(); i++) {
311+
} else if (tesseract_->lstm_choice_mode == 2 && choiceMap != nullptr) {
312+
for (size_t i = 0; i < choiceMap->size(); i++) {
309313
std::vector<std::pair<const char*, float>> timestep =
310-
(*confidencemap)[i];
314+
(*choiceMap)[i];
311315
if (timestep.size() > 0) {
312316
hocr_str << "\n <span class='ocrx_cinfo'"
313317
<< " id='"
314318
<< "lstm_choices_" << page_id << "_" << wcnt << "_" << tcnt
315-
<< "'"
316-
<< " chosen='" << timestep[0].first << "'>";
317-
for (size_t j = 1; j < timestep.size(); j++) {
319+
<< "'>";
320+
for (size_t j = 0; j < timestep.size(); j++) {
318321
hocr_str << "<span class='ocr_glyph'"
319322
<< " id='"
320323
<< "choice_" << page_id << "_" << wcnt << "_" << gcnt
@@ -333,10 +336,9 @@ char* TessBaseAPI::GetHOCRText(ETEXT_DESC* monitor, int page_number) {
333336
(*symbolMap)[j];
334337
hocr_str << "\n <span class='ocr_symbol'"
335338
<< " id='"
336-
<< "symbolstep_" << page_id << "_" << wcnt << "_" << scnt
337-
<< "'>"
338-
<< timesteps[0][0].first;
339-
for (size_t i = 1; i < timesteps.size(); i++) {
339+
<< "symbol_" << page_id << "_" << wcnt << "_" << scnt
340+
<< "'>";
341+
for (size_t i = 0; i < timesteps.size(); i++) {
340342
hocr_str << "\n <span class='ocrx_cinfo'"
341343
<< " id='"
342344
<< "timestep_" << page_id << "_" << wcnt << "_" << tcnt

src/ccmain/ltrresultiterator.cpp

+77-23
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,17 @@ bool LTRResultIterator::SymbolIsDropcap() const {
358358
ChoiceIterator::ChoiceIterator(const LTRResultIterator& result_it) {
359359
ASSERT_HOST(result_it.it_->word() != nullptr);
360360
word_res_ = result_it.it_->word();
361+
oemLSTM_ = word_res_->tesseract->AnyLSTMLang();
362+
oemLegacy_ = word_res_->tesseract->AnyTessLang();
361363
BLOB_CHOICE_LIST* choices = nullptr;
364+
tstep_index_ = &result_it.blob_index_;
365+
if (oemLSTM_ && !oemLegacy_ && &word_res_->accumulated_timesteps != nullptr) {
366+
if (word_res_->leadingSpace)
367+
LSTM_choices_ = &word_res_->accumulated_timesteps[(*tstep_index_) + 1];
368+
else
369+
LSTM_choices_ = &word_res_->accumulated_timesteps[*tstep_index_];
370+
filterSpaces();
371+
}
362372
if (word_res_->ratings != nullptr)
363373
choices = word_res_->GetBlobChoices(result_it.blob_index_);
364374
if (choices != nullptr && !choices->empty()) {
@@ -367,49 +377,93 @@ ChoiceIterator::ChoiceIterator(const LTRResultIterator& result_it) {
367377
} else {
368378
choice_it_ = nullptr;
369379
}
370-
if (&word_res_->symbol_steps != nullptr && !word_res_->symbol_steps.empty()) {
371-
symbol_step_it_ = word_res_->symbol_steps.begin();
380+
if (LSTM_choices_ != nullptr && !LSTM_choices_->empty()) {
381+
LSTM_mode_ = true;
382+
LSTM_choice_it_ = LSTM_choices_->begin();
372383
}
373384
}
374-
375385
ChoiceIterator::~ChoiceIterator() { delete choice_it_; }
376386

377387
// Moves to the next choice for the symbol and returns false if there
378388
// are none left.
379389
bool ChoiceIterator::Next() {
380-
if (choice_it_ == nullptr) return false;
381-
if (&word_res_->symbol_steps != nullptr) {
382-
if (symbol_step_it_ == word_res_->symbol_steps.end()) {
383-
symbol_step_it_ = word_res_->symbol_steps.begin();
390+
if (LSTM_mode_) {
391+
if (LSTM_choice_it_ != LSTM_choices_->end() &&
392+
next(LSTM_choice_it_) == LSTM_choices_->end()) {
393+
return false;
384394
} else {
385-
symbol_step_it_++;
386-
}
395+
++LSTM_choice_it_;
396+
return true;
397+
}
398+
} else {
399+
if (choice_it_ == nullptr) return false;
400+
choice_it_->forward();
401+
return !choice_it_->cycled_list();
387402
}
388-
choice_it_->forward();
389-
return !choice_it_->cycled_list();
390403
}
391404

392405
// Returns the null terminated UTF-8 encoded text string for the current
393406
// choice. Do NOT use delete [] to free after use.
394407
const char* ChoiceIterator::GetUTF8Text() const {
395-
if (choice_it_ == nullptr) return nullptr;
396-
UNICHAR_ID id = choice_it_->data()->unichar_id();
397-
return word_res_->uch_set->id_to_unichar_ext(id);
408+
if (LSTM_mode_) {
409+
std::pair<const char*, float> choice = *LSTM_choice_it_;
410+
return choice.first;
411+
} else {
412+
if (choice_it_ == nullptr) return nullptr;
413+
UNICHAR_ID id = choice_it_->data()->unichar_id();
414+
return word_res_->uch_set->id_to_unichar_ext(id);
415+
}
398416
}
399417

400-
// Returns the confidence of the current choice.
401-
// The number should be interpreted as a percent probability. (0.0f-100.0f)
418+
// Returns the confidence of the current choice depending on the used language
419+
// data. If only LSTM traineddata is used the value range is 0.0f - 1.0f. All
420+
// choices for one symbol should roughly add up to 1.0f.
421+
// If only traineddata of the legacy engine is used, the number should be
422+
// interpreted as a percent probability. (0.0f-100.0f) In this case probabilities
423+
// won't add up to 100. Each one stands on its own.
402424
float ChoiceIterator::Confidence() const {
403-
if (choice_it_ == nullptr) return 0.0f;
404-
float confidence = 100 + 5 * choice_it_->data()->certainty();
405-
if (confidence < 0.0f) confidence = 0.0f;
406-
if (confidence > 100.0f) confidence = 100.0f;
407-
return confidence;
425+
if (LSTM_mode_) {
426+
std::pair<const char*, float> choice = *LSTM_choice_it_;
427+
return choice.second;
428+
} else {
429+
if (choice_it_ == nullptr) return 0.0f;
430+
float confidence = 100 + 5 * choice_it_->data()->certainty();
431+
if (confidence < 0.0f) confidence = 0.0f;
432+
if (confidence > 100.0f) confidence = 100.0f;
433+
return confidence;
434+
}
408435
}
409436

437+
// Returns the set of timesteps which belong to the current symbol
410438
std::vector<std::vector<std::pair<const char*, float>>>*
411439
ChoiceIterator::Timesteps() const {
412-
if (&word_res_->symbol_steps == nullptr) return nullptr;
413-
return &*symbol_step_it_;
440+
if (&word_res_->symbol_steps == nullptr || !LSTM_mode_) return nullptr;
441+
if (word_res_->leadingSpace) {
442+
return &word_res_->symbol_steps[*(tstep_index_) + 1];
443+
} else {
444+
return &word_res_->symbol_steps[*tstep_index_];
445+
}
446+
}
447+
448+
void ChoiceIterator::filterSpaces() {
449+
if (LSTM_choices_->empty()) return;
450+
std::vector<std::pair<const char*, float>>::iterator it =
451+
LSTM_choices_->begin();
452+
bool found_space = false;
453+
float sum = 0;
454+
for (it; it != LSTM_choices_->end();) {
455+
if (!strcmp(it->first, " ")) {
456+
it = LSTM_choices_->erase(it);
457+
found_space = true;
458+
} else {
459+
sum += it->second;
460+
++it;
461+
}
462+
}
463+
if (found_space) {
464+
for (it = LSTM_choices_->begin(); it != LSTM_choices_->end(); ++it) {
465+
it->second /= sum;
466+
}
467+
}
414468
}
415469
} // namespace tesseract.

src/ccmain/ltrresultiterator.h

+18-7
Original file line numberDiff line numberDiff line change
@@ -208,25 +208,36 @@ class ChoiceIterator {
208208
// internal structure and should NOT be delete[]ed to free after use.
209209
const char* GetUTF8Text() const;
210210

211-
// Returns the confidence of the current choice.
212-
// The number should be interpreted as a percent probability. (0.0f-100.0f)
211+
// Returns the confidence of the current choice depending on the used language
212+
// data. If only LSTM traineddata is used the value range is 0.0f - 1.0f. All
213+
// choices for one symbol should roughly add up to 1.0f.
214+
// If only traineddata of the legacy engine is used, the number should be
215+
// interpreted as a percent probability. (0.0f-100.0f) In this case
216+
// probabilities won't add up to 100. Each one stands on its own.
213217
float Confidence() const;
214218

215219
// Returns a vector containing all timesteps, which belong to the currently
216220
// selected symbol. A timestep is a vector containing pairs of symbols and
217221
// floating point numbers. The number states the probability for the
218222
// corresponding symbol.
219-
std::vector<std::vector<std::pair<const char*, float>>>*
220-
Timesteps() const;
223+
std::vector<std::vector<std::pair<const char*, float>>>* Timesteps() const;
221224

222225
private:
226+
// Clears the remaining spaces out of the results and adapts the probabilities.
227+
void filterSpaces();
223228
// Pointer to the WERD_RES object owned by the API.
224229
WERD_RES* word_res_;
225230
// Iterator over the blob choices.
226231
BLOB_CHOICE_IT* choice_it_;
227-
//Iterator over the symbol steps.
228-
std::vector<std::vector<std::vector<std::pair<const char*, float>>>>::iterator
229-
symbol_step_it_;
232+
std::vector<std::pair<const char*, float>>* LSTM_choices_ = nullptr;
233+
std::vector<std::pair<const char*, float>>::iterator LSTM_choice_it_;
234+
235+
const int* tstep_index_;
236+
bool LSTM_mode_ = false;
237+
// True when there is LSTM engine related trained data.
238+
bool oemLSTM_;
239+
// true when there is legacy engine related trained data
240+
bool oemLegacy_;
230241
};
231242

232243
} // namespace tesseract.

src/ccmain/resultiterator.cpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -604,18 +604,26 @@ char* ResultIterator::GetUTF8Text(PageIteratorLevel level) const {
604604
strncpy(result, text.string(), length);
605605
return result;
606606
}
607+
std::vector<std::vector<std::pair<const char*, float>>>*
608+
ResultIterator::GetRawLSTMTimesteps() const {
609+
if (it_->word() != nullptr) {
610+
return &it_->word()->raw_timesteps;
611+
} else {
612+
return nullptr;
613+
}
614+
}
607615

608616
std::vector<std::vector<std::pair<const char*, float>>>*
609617
ResultIterator::GetBestLSTMSymbolChoices() const {
610618
if (it_->word() != nullptr) {
611-
return &it_->word()->timesteps;
619+
return &it_->word()->accumulated_timesteps;
612620
} else {
613621
return nullptr;
614622
}
615623
}
616624

617625
std::vector<std::vector<std::vector<std::pair<const char*, float>>>>*
618-
ResultIterator::GetBestSegmentedLSTMSymbolChoices() const {
626+
ResultIterator::GetSegmentedLSTMTimesteps() const {
619627
if (it_->word() != nullptr) {
620628
return &it_->word()->symbol_steps;
621629
} else {

src/ccmain/resultiterator.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,12 @@ class TESS_API ResultIterator : public LTRResultIterator {
100100
/**
101101
* Returns the LSTM choices for every LSTM timestep for the current word.
102102
*/
103+
virtual std::vector<std::vector<std::pair<const char*, float>>>*
104+
GetRawLSTMTimesteps() const;
103105
virtual std::vector<std::vector<std::pair<const char*, float>>>*
104106
GetBestLSTMSymbolChoices() const;
105107
virtual std::vector<std::vector<std::vector<std::pair<const char*, float>>>>*
106-
GetBestSegmentedLSTMSymbolChoices() const;
108+
GetSegmentedLSTMTimesteps() const;
107109

108110
/**
109111
* Return whether the current paragraph's dominant reading direction

src/ccmain/tesseractclass.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -524,11 +524,12 @@ Tesseract::Tesseract()
524524
this->params()),
525525
INT_MEMBER(lstm_choice_mode, 0,
526526
"Allows to include alternative symbols choices in the hOCR output. "
527-
"Valid input values are 0, 1 and 2. 0 is the default value. "
527+
"Valid input values are 0, 1, 2 and 3. 0 is the default value. "
528528
"With 1 the alternative symbol choices per timestep are included. "
529-
"With 2 the alternative symbol choices are accumulated per character."
530-
"With 3 the alternative symbol choices per timestep are included and "
531-
"separated by the suggested segmentation of Tesseract",
529+
"With 2 the alternative symbol choices are accumulated per "
530+
"character. "
531+
"With 3 the alternative symbol choices per timestep are included "
532+
"and separated by the suggested segmentation of Tesseract",
532533
this->params()),
533534

534535
backup_config_file_(nullptr),

src/ccmain/tesseractclass.h

+7-5
Original file line numberDiff line numberDiff line change
@@ -1124,12 +1124,14 @@ class Tesseract : public Wordrec {
11241124
STRING_VAR_H(page_separator, "\f",
11251125
"Page separator (default is form feed control character)");
11261126
INT_VAR_H(lstm_choice_mode, 0,
1127-
"Allows to include alternative symbols choices in the hOCR output. "
1128-
"Valid input values are 0, 1 and 2. 0 is the default value. "
1127+
"Allows to include alternative symbols choices in the hOCR "
1128+
"output. "
1129+
"Valid input values are 0, 1, 2 and 3. 0 is the default value. "
11291130
"With 1 the alternative symbol choices per timestep are included. "
1130-
"With 2 the alternative symbol choices are accumulated per character."
1131-
"With 3 the alternative symbol choices per timestep are included and "
1132-
"separated by the suggested segmentation of Tesseract");
1131+
"With 2 the alternative symbol choices are accumulated per "
1132+
"character. "
1133+
"With 3 the alternative symbol choices per timestep are included "
1134+
"and separated by the suggested segmentation of Tesseract");
11331135

11341136
//// ambigsrecog.cpp /////////////////////////////////////////////////////////
11351137
FILE *init_recog_training(const STRING &fname);

src/ccstruct/pageres.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,12 @@ class WERD_RES : public ELIST_LINK {
221221
// blob i and blob i+1.
222222
GenericVector<int> blob_gaps;
223223
// Stores the lstm choices of every timestep
224-
std::vector<std::vector<std::pair<const char*, float>>> timesteps;
224+
std::vector<std::vector<std::pair<const char*, float>>> raw_timesteps;
225+
std::vector<std::vector<std::pair<const char*, float>>> accumulated_timesteps;
225226
std::vector<std::vector<std::vector<std::pair<const char*, float>>>>
226227
symbol_steps;
228+
// Stores whether the timestep vector starts with a space.
229+
bool leadingSpace = false;
227230
// Ratings matrix contains classifier choices for each classified combination
228231
// of blobs. The dimension is the same as the number of blobs in chopped_word
229232
// and the leading diagonal corresponds to classifier results of the blobs

0 commit comments

Comments
 (0)