Skip to content

Commit 2633fef

Browse files
committed
Part 2 of separating out the unicharset from the LSTM model, fixing command line for training
1 parent 61adbdf commit 2633fef

19 files changed

+625
-222
lines changed

dict/dawg.cpp

+15-14
Original file line numberDiff line numberDiff line change
@@ -339,16 +339,15 @@ bool SquishedDawg::read_squished_dawg(TFile *file) {
339339
return true;
340340
}
341341

342-
NODE_MAP SquishedDawg::build_node_map(inT32 *num_nodes) const {
342+
std::unique_ptr<EDGE_REF[]> SquishedDawg::build_node_map(
343+
inT32 *num_nodes) const {
343344
EDGE_REF edge;
344-
NODE_MAP node_map;
345+
std::unique_ptr<EDGE_REF[]> node_map(new EDGE_REF[num_edges_]);
345346
inT32 node_counter;
346347
inT32 num_edges;
347348

348-
node_map = (NODE_MAP) malloc(sizeof(EDGE_REF) * num_edges_);
349-
350349
for (edge = 0; edge < num_edges_; edge++) // init all slots
351-
node_map [edge] = -1;
350+
node_map[edge] = -1;
352351

353352
node_counter = num_forward_edges(0);
354353

@@ -366,33 +365,34 @@ NODE_MAP SquishedDawg::build_node_map(inT32 *num_nodes) const {
366365
edge--;
367366
}
368367
}
369-
return (node_map);
368+
return node_map;
370369
}
371370

372-
void SquishedDawg::write_squished_dawg(FILE *file) {
371+
bool SquishedDawg::write_squished_dawg(TFile *file) {
373372
EDGE_REF edge;
374373
inT32 num_edges;
375374
inT32 node_count = 0;
376-
NODE_MAP node_map;
377375
EDGE_REF old_index;
378376
EDGE_RECORD temp_record;
379377

380378
if (debug_level_) tprintf("write_squished_dawg\n");
381379

382-
node_map = build_node_map(&node_count);
380+
std::unique_ptr<EDGE_REF[]> node_map(build_node_map(&node_count));
383381

384382
// Write the magic number to help detecting a change in endianness.
385383
inT16 magic = kDawgMagicNumber;
386-
fwrite(&magic, sizeof(inT16), 1, file);
387-
fwrite(&unicharset_size_, sizeof(inT32), 1, file);
384+
if (file->FWrite(&magic, sizeof(magic), 1) != 1) return false;
385+
if (file->FWrite(&unicharset_size_, sizeof(unicharset_size_), 1) != 1)
386+
return false;
388387

389388
// Count the number of edges in this Dawg.
390389
num_edges = 0;
391390
for (edge=0; edge < num_edges_; edge++)
392391
if (forward_edge(edge))
393392
num_edges++;
394393

395-
fwrite(&num_edges, sizeof(inT32), 1, file); // write edge count to file
394+
// Write edge count to file.
395+
if (file->FWrite(&num_edges, sizeof(num_edges), 1) != 1) return false;
396396

397397
if (debug_level_) {
398398
tprintf("%d nodes in DAWG\n", node_count);
@@ -405,7 +405,8 @@ void SquishedDawg::write_squished_dawg(FILE *file) {
405405
old_index = next_node_from_edge_rec(edges_[edge]);
406406
set_next_node(edge, node_map[old_index]);
407407
temp_record = edges_[edge];
408-
fwrite(&(temp_record), sizeof(EDGE_RECORD), 1, file);
408+
if (file->FWrite(&temp_record, sizeof(temp_record), 1) != 1)
409+
return false;
409410
set_next_node(edge, old_index);
410411
} while (!last_edge(edge++));
411412

@@ -416,7 +417,7 @@ void SquishedDawg::write_squished_dawg(FILE *file) {
416417
edge--;
417418
}
418419
}
419-
free(node_map);
420+
return true;
420421
}
421422

422423
} // namespace tesseract

dict/dawg.h

+15-11
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@
3131
I n c l u d e s
3232
----------------------------------------------------------------------*/
3333

34+
#include <memory>
3435
#include "elst.h"
35-
#include "ratngs.h"
3636
#include "params.h"
37+
#include "ratngs.h"
3738
#include "tesscallback.h"
3839

3940
#ifndef __GNUC__
@@ -483,18 +484,22 @@ class SquishedDawg : public Dawg {
483484
void print_node(NODE_REF node, int max_num_edges) const;
484485

485486
/// Writes the squished/reduced Dawg to a file.
486-
void write_squished_dawg(FILE *file);
487+
bool write_squished_dawg(TFile *file);
487488

488489
/// Opens the file with the given filename and writes the
489490
/// squished/reduced Dawg to the file.
490-
void write_squished_dawg(const char *filename) {
491-
FILE *file = fopen(filename, "wb");
492-
if (file == NULL) {
493-
tprintf("Error opening %s\n", filename);
494-
exit(1);
491+
bool write_squished_dawg(const char *filename) {
492+
TFile file;
493+
file.OpenWrite(nullptr);
494+
if (!this->write_squished_dawg(&file)) {
495+
tprintf("Error serializing %s\n", filename);
496+
return false;
495497
}
496-
this->write_squished_dawg(file);
497-
fclose(file);
498+
if (!file.CloseWrite(filename, nullptr)) {
499+
tprintf("Error writing file %s\n", filename);
500+
return false;
501+
}
502+
return true;
498503
}
499504

500505
private:
@@ -549,8 +554,7 @@ class SquishedDawg : public Dawg {
549554
tprintf("__________________________\n");
550555
}
551556
/// Constructs a mapping from the memory node indices to disk node indices.
552-
NODE_MAP build_node_map(inT32 *num_nodes) const;
553-
557+
std::unique_ptr<EDGE_REF[]> build_node_map(inT32 *num_nodes) const;
554558

555559
// Member variables.
556560
EDGE_ARRAY edges_;

dict/trie.cpp

+17-22
Original file line numberDiff line numberDiff line change
@@ -290,51 +290,46 @@ bool Trie::read_and_add_word_list(const char *filename,
290290
const UNICHARSET &unicharset,
291291
Trie::RTLReversePolicy reverse_policy) {
292292
GenericVector<STRING> word_list;
293-
if (!read_word_list(filename, unicharset, reverse_policy, &word_list))
294-
return false;
293+
if (!read_word_list(filename, &word_list)) return false;
295294
word_list.sort(sort_strings_by_dec_length);
296-
return add_word_list(word_list, unicharset);
295+
return add_word_list(word_list, unicharset, reverse_policy);
297296
}
298297

299298
bool Trie::read_word_list(const char *filename,
300-
const UNICHARSET &unicharset,
301-
Trie::RTLReversePolicy reverse_policy,
302299
GenericVector<STRING>* words) {
303300
FILE *word_file;
304-
char string[CHARS_PER_LINE];
301+
char line_str[CHARS_PER_LINE];
305302
int word_count = 0;
306303

307304
word_file = fopen(filename, "rb");
308305
if (word_file == NULL) return false;
309306

310-
while (fgets(string, CHARS_PER_LINE, word_file) != NULL) {
311-
chomp_string(string); // remove newline
312-
WERD_CHOICE word(string, unicharset);
313-
if ((reverse_policy == RRP_REVERSE_IF_HAS_RTL &&
314-
word.has_rtl_unichar_id()) ||
315-
reverse_policy == RRP_FORCE_REVERSE) {
316-
word.reverse_and_mirror_unichar_ids();
317-
}
307+
while (fgets(line_str, sizeof(line_str), word_file) != NULL) {
308+
chomp_string(line_str); // remove newline
309+
STRING word_str(line_str);
318310
++word_count;
319311
if (debug_level_ && word_count % 10000 == 0)
320312
tprintf("Read %d words so far\n", word_count);
321-
if (word.length() != 0 && !word.contains_unichar_id(INVALID_UNICHAR_ID)) {
322-
words->push_back(word.unichar_string());
323-
} else if (debug_level_) {
324-
tprintf("Skipping invalid word %s\n", string);
325-
if (debug_level_ >= 3) word.print();
326-
}
313+
words->push_back(word_str);
327314
}
328315
if (debug_level_)
329316
tprintf("Read %d words total.\n", word_count);
330317
fclose(word_file);
331318
return true;
332319
}
333320

334-
bool Trie::add_word_list(const GenericVector<STRING>& words,
335-
const UNICHARSET &unicharset) {
321+
bool Trie::add_word_list(const GenericVector<STRING> &words,
322+
const UNICHARSET &unicharset,
323+
Trie::RTLReversePolicy reverse_policy) {
336324
for (int i = 0; i < words.size(); ++i) {
337325
WERD_CHOICE word(words[i].string(), unicharset);
326+
if (word.length() == 0 || word.contains_unichar_id(INVALID_UNICHAR_ID))
327+
continue;
328+
if ((reverse_policy == RRP_REVERSE_IF_HAS_RTL &&
329+
word.has_rtl_unichar_id()) ||
330+
reverse_policy == RRP_FORCE_REVERSE) {
331+
word.reverse_and_mirror_unichar_ids();
332+
}
338333
if (!word_in_dawg(word)) {
339334
add_word_to_dawg(word);
340335
if (!word_in_dawg(word)) {

dict/trie.h

+5-7
Original file line numberDiff line numberDiff line change
@@ -177,18 +177,16 @@ class Trie : public Dawg {
177177
const UNICHARSET &unicharset,
178178
Trie::RTLReversePolicy reverse);
179179

180-
// Reads a list of words from the given file, applying the reverse_policy,
181-
// according to information in the unicharset.
180+
// Reads a list of words from the given file.
182181
// Returns false on error.
183182
bool read_word_list(const char *filename,
184-
const UNICHARSET &unicharset,
185-
Trie::RTLReversePolicy reverse_policy,
186183
GenericVector<STRING>* words);
187184
// Adds a list of words previously read using read_word_list to the trie
188-
// using the given unicharset to convert to unichar-ids.
185+
// using the given unicharset and reverse_policy to convert to unichar-ids.
189186
// Returns false on error.
190-
bool add_word_list(const GenericVector<STRING>& words,
191-
const UNICHARSET &unicharset);
187+
bool add_word_list(const GenericVector<STRING> &words,
188+
const UNICHARSET &unicharset,
189+
Trie::RTLReversePolicy reverse_policy);
192190

193191
// Inserts the list of patterns from the given file into the Trie.
194192
// The pattern list file should contain one pattern per line in UTF-8 format.

lstm/lstmtrainer.cpp

+12-64
Original file line numberDiff line numberDiff line change
@@ -130,22 +130,6 @@ bool LSTMTrainer::TryLoadingCheckpoint(const char* filename) {
130130
return checkpoint_reader_->Run(data, this);
131131
}
132132

133-
// Initializes the character set encode/decode mechanism.
134-
// train_flags control training behavior according to the TrainingFlags
135-
// enum, including character set encoding.
136-
// script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided,
137-
// fully initializes the unicharset from the universal unicharsets.
138-
// Note: Call before InitNetwork!
139-
void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset,
140-
const STRING& script_dir, int train_flags) {
141-
EmptyConstructor();
142-
training_flags_ = train_flags;
143-
ccutil_.unicharset.CopyFrom(unicharset);
144-
null_char_ = GetUnicharset().has_special_codes() ? UNICHAR_BROKEN
145-
: GetUnicharset().size();
146-
SetUnicharsetProperties(script_dir);
147-
}
148-
149133
// Initializes the trainer with a network_spec in the network description
150134
// net_flags control network behavior according to the NetworkFlags enum.
151135
// There isn't really much difference between them - only where the effects
@@ -278,9 +262,10 @@ void LSTMTrainer::DebugNetwork() {
278262
// Loads a set of lstmf files that were created using the lstm.train config to
279263
// tesseract into memory ready for training. Returns false if nothing was
280264
// loaded.
281-
bool LSTMTrainer::LoadAllTrainingData(const GenericVector<STRING>& filenames) {
265+
bool LSTMTrainer::LoadAllTrainingData(const GenericVector<STRING>& filenames,
266+
CachingStrategy cache_strategy) {
282267
training_data_.Clear();
283-
return training_data_.LoadDocuments(filenames, CacheStrategy(), file_reader_);
268+
return training_data_.LoadDocuments(filenames, cache_strategy, file_reader_);
284269
}
285270

286271
// Keeps track of best and locally worst char error_rate and launches tests
@@ -908,6 +893,15 @@ bool LSTMTrainer::ReadLocalTrainingDump(const TessdataManager* mgr,
908893
return DeSerialize(mgr, &fp);
909894
}
910895

896+
// Writes the full recognition traineddata to the given filename.
897+
bool LSTMTrainer::SaveTraineddata(const STRING& filename) {
898+
GenericVector<char> recognizer_data;
899+
SaveRecognitionDump(&recognizer_data);
900+
mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
901+
recognizer_data.size());
902+
return mgr_.SaveFile(filename, file_writer_);
903+
}
904+
911905
// Writes the recognizer to memory, so that it can be used for testing later.
912906
void LSTMTrainer::SaveRecognitionDump(GenericVector<char>* data) const {
913907
TFile fp;
@@ -964,52 +958,6 @@ void LSTMTrainer::EmptyConstructor() {
964958
InitIterations();
965959
}
966960

967-
// Sets the unicharset properties using the given script_dir as a source of
968-
// script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets
969-
// up the recoder_ to simplify the unicharset.
970-
void LSTMTrainer::SetUnicharsetProperties(const STRING& script_dir) {
971-
tprintf("Setting unichar properties\n");
972-
for (int s = 0; s < GetUnicharset().get_script_table_size(); ++s) {
973-
if (strcmp("NULL", GetUnicharset().get_script_from_script_id(s)) == 0)
974-
continue;
975-
// Load the unicharset for the script if available.
976-
STRING filename = script_dir + "/" +
977-
GetUnicharset().get_script_from_script_id(s) +
978-
".unicharset";
979-
UNICHARSET script_set;
980-
GenericVector<char> data;
981-
if ((*file_reader_)(filename, &data) &&
982-
script_set.load_from_inmemory_file(&data[0], data.size())) {
983-
tprintf("Setting properties for script %s\n",
984-
GetUnicharset().get_script_from_script_id(s));
985-
ccutil_.unicharset.SetPropertiesFromOther(script_set);
986-
}
987-
}
988-
if (IsRecoding()) {
989-
STRING filename = script_dir + "/radical-stroke.txt";
990-
GenericVector<char> data;
991-
if ((*file_reader_)(filename, &data)) {
992-
data += '\0';
993-
STRING stroke_table = &data[0];
994-
if (recoder_.ComputeEncoding(GetUnicharset(), null_char_,
995-
&stroke_table)) {
996-
RecodedCharID code;
997-
recoder_.EncodeUnichar(null_char_, &code);
998-
null_char_ = code(0);
999-
// Space should encode as itself.
1000-
recoder_.EncodeUnichar(UNICHAR_SPACE, &code);
1001-
ASSERT_HOST(code(0) == UNICHAR_SPACE);
1002-
return;
1003-
}
1004-
} else {
1005-
tprintf("Failed to load radical-stroke info from: %s\n",
1006-
filename.string());
1007-
}
1008-
}
1009-
training_flags_ |= TF_COMPRESS_UNICHARSET;
1010-
recoder_.SetupPassThrough(GetUnicharset());
1011-
}
1012-
1013961
// Outputs the string and periodically displays the given network inputs
1014962
// as an image in the given window, and the corresponding labels at the
1015963
// corresponding x_starts.

lstm/lstmtrainer.h

+5-17
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,6 @@ class LSTMTrainer : public LSTMRecognizer {
101101
// false in case of failure.
102102
bool TryLoadingCheckpoint(const char* filename);
103103

104-
// Initializes the character set encode/decode mechanism.
105-
// train_flags control training behavior according to the TrainingFlags
106-
// enum, including character set encoding.
107-
// script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided,
108-
// fully initializes the unicharset from the universal unicharsets.
109-
// Note: Call before InitNetwork!
110-
void InitCharSet(const UNICHARSET& unicharset, const STRING& script_dir,
111-
int train_flags);
112104
// Initializes the character set encode/decode mechanism directly from a
113105
// previously setup traineddata containing dawgs, UNICHARSET and
114106
// UnicharCompress. Note: Call before InitNetwork!
@@ -186,7 +178,8 @@ class LSTMTrainer : public LSTMRecognizer {
186178
// Loads a set of lstmf files that were created using the lstm.train config to
187179
// tesseract into memory ready for training. Returns false if nothing was
188180
// loaded.
189-
bool LoadAllTrainingData(const GenericVector<STRING>& filenames);
181+
bool LoadAllTrainingData(const GenericVector<STRING>& filenames,
182+
CachingStrategy cache_strategy);
190183

191184
// Keeps track of best and locally worst error rate, using internally computed
192185
// values. See MaintainCheckpointsSpecific for more detail.
@@ -315,12 +308,12 @@ class LSTMTrainer : public LSTMRecognizer {
315308
// Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
316309
void SetupCheckpointInfo();
317310

311+
// Writes the full recognition traineddata to the given filename.
312+
bool SaveTraineddata(const STRING& filename);
313+
318314
// Writes the recognizer to memory, so that it can be used for testing later.
319315
void SaveRecognitionDump(GenericVector<char>* data) const;
320316

321-
// Writes current best model to a file, unless it has already been written.
322-
bool SaveBestModel(FileWriter writer) const;
323-
324317
// Returns a suitable filename for a training dump, based on the model_base_,
325318
// the iteration and the error rates.
326319
STRING DumpFilename() const;
@@ -336,11 +329,6 @@ class LSTMTrainer : public LSTMRecognizer {
336329
// Factored sub-constructor sets up reasonable default values.
337330
void EmptyConstructor();
338331

339-
// Sets the unicharset properties using the given script_dir as a source of
340-
// script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets
341-
// up the recoder_ to simplify the unicharset.
342-
void SetUnicharsetProperties(const STRING& script_dir);
343-
344332
// Outputs the string and periodically displays the given network inputs
345333
// as an image in the given window, and the corresponding labels at the
346334
// corresponding x_starts.

0 commit comments

Comments
 (0)