Skip to content

Commit b7b8dba

Browse files
committed
LSTMTrainer: Use new serialization API
Improve also portability by using int32_t instead of int for a serialized member variable. Signed-off-by: Stefan Weil <sw@weilnetz.de>
1 parent 1dcda1a commit b7b8dba

File tree

2 files changed

+32
-62
lines changed

2 files changed

+32
-62
lines changed

src/lstm/lstmtrainer.cpp

+30-60
Original file line numberDiff line numberDiff line change
@@ -431,38 +431,25 @@ bool LSTMTrainer::TransitionTrainingStage(float error_threshold) {
431431
bool LSTMTrainer::Serialize(SerializeAmount serialize_amount,
432432
const TessdataManager* mgr, TFile* fp) const {
433433
if (!LSTMRecognizer::Serialize(mgr, fp)) return false;
434-
if (fp->FWrite(&learning_iteration_, sizeof(learning_iteration_), 1) != 1)
435-
return false;
436-
if (fp->FWrite(&prev_sample_iteration_, sizeof(prev_sample_iteration_), 1) !=
437-
1)
438-
return false;
439-
if (fp->FWrite(&perfect_delay_, sizeof(perfect_delay_), 1) != 1) return false;
440-
if (fp->FWrite(&last_perfect_training_iteration_,
441-
sizeof(last_perfect_training_iteration_), 1) != 1)
442-
return false;
434+
if (!fp->Serialize(&learning_iteration_)) return false;
435+
if (!fp->Serialize(&prev_sample_iteration_)) return false;
436+
if (!fp->Serialize(&perfect_delay_)) return false;
437+
if (!fp->Serialize(&last_perfect_training_iteration_)) return false;
443438
for (int i = 0; i < ET_COUNT; ++i) {
444439
if (!error_buffers_[i].Serialize(fp)) return false;
445440
}
446-
if (fp->FWrite(&error_rates_, sizeof(error_rates_), 1) != 1) return false;
447-
if (fp->FWrite(&training_stage_, sizeof(training_stage_), 1) != 1)
448-
return false;
441+
if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) return false;
442+
if (!fp->Serialize(&training_stage_)) return false;
449443
uint8_t amount = serialize_amount;
450-
if (fp->FWrite(&amount, sizeof(amount), 1) != 1) return false;
444+
if (!fp->Serialize(&amount)) return false;
451445
if (serialize_amount == LIGHT) return true; // We are done.
452-
if (fp->FWrite(&best_error_rate_, sizeof(best_error_rate_), 1) != 1)
453-
return false;
454-
if (fp->FWrite(&best_error_rates_, sizeof(best_error_rates_), 1) != 1)
455-
return false;
456-
if (fp->FWrite(&best_iteration_, sizeof(best_iteration_), 1) != 1)
457-
return false;
458-
if (fp->FWrite(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1)
459-
return false;
460-
if (fp->FWrite(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1)
461-
return false;
462-
if (fp->FWrite(&worst_iteration_, sizeof(worst_iteration_), 1) != 1)
463-
return false;
464-
if (fp->FWrite(&stall_iteration_, sizeof(stall_iteration_), 1) != 1)
465-
return false;
446+
if (!fp->Serialize(&best_error_rate_)) return false;
447+
if (!fp->Serialize(&best_error_rates_[0], countof(best_error_rates_))) return false;
448+
if (!fp->Serialize(&best_iteration_)) return false;
449+
if (!fp->Serialize(&worst_error_rate_)) return false;
450+
if (!fp->Serialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false;
451+
if (!fp->Serialize(&worst_iteration_)) return false;
452+
if (!fp->Serialize(&stall_iteration_)) return false;
466453
if (!best_model_data_.Serialize(fp)) return false;
467454
if (!worst_model_data_.Serialize(fp)) return false;
468455
if (serialize_amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp))
@@ -473,16 +460,14 @@ bool LSTMTrainer::Serialize(SerializeAmount serialize_amount,
473460
if (!sub_data.Serialize(fp)) return false;
474461
if (!best_error_history_.Serialize(fp)) return false;
475462
if (!best_error_iterations_.Serialize(fp)) return false;
476-
if (fp->FWrite(&improvement_steps_, sizeof(improvement_steps_), 1) != 1)
477-
return false;
478-
return true;
463+
return fp->Serialize(&improvement_steps_);
479464
}
480465

481466
// Reads from the given file. Returns false in case of error.
482467
// NOTE: It is assumed that the trainer is never read cross-endian.
483468
bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
484469
if (!LSTMRecognizer::DeSerialize(mgr, fp)) return false;
485-
if (fp->FRead(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) {
470+
if (!fp->DeSerialize(&learning_iteration_)) {
486471
// Special case. If we successfully decoded the recognizer, but fail here
487472
// then it means we were just given a recognizer, so issue a warning and
488473
// allow it.
@@ -491,37 +476,24 @@ bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
491476
network_->SetEnableTraining(TS_ENABLED);
492477
return true;
493478
}
494-
if (fp->FReadEndian(&prev_sample_iteration_, sizeof(prev_sample_iteration_),
495-
1) != 1)
496-
return false;
497-
if (fp->FReadEndian(&perfect_delay_, sizeof(perfect_delay_), 1) != 1)
498-
return false;
499-
if (fp->FReadEndian(&last_perfect_training_iteration_,
500-
sizeof(last_perfect_training_iteration_), 1) != 1)
501-
return false;
479+
if (!fp->DeSerialize(&prev_sample_iteration_)) return false;
480+
if (!fp->DeSerialize(&perfect_delay_)) return false;
481+
if (!fp->DeSerialize(&last_perfect_training_iteration_)) return false;
502482
for (int i = 0; i < ET_COUNT; ++i) {
503483
if (!error_buffers_[i].DeSerialize(fp)) return false;
504484
}
505-
if (fp->FRead(&error_rates_, sizeof(error_rates_), 1) != 1) return false;
506-
if (fp->FReadEndian(&training_stage_, sizeof(training_stage_), 1) != 1)
507-
return false;
485+
if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) return false;
486+
if (!fp->DeSerialize(&training_stage_)) return false;
508487
uint8_t amount;
509-
if (fp->FRead(&amount, sizeof(amount), 1) != 1) return false;
488+
if (!fp->DeSerialize(&amount)) return false;
510489
if (amount == LIGHT) return true; // Don't read the rest.
511-
if (fp->FReadEndian(&best_error_rate_, sizeof(best_error_rate_), 1) != 1)
512-
return false;
513-
if (fp->FReadEndian(&best_error_rates_, sizeof(best_error_rates_), 1) != 1)
514-
return false;
515-
if (fp->FReadEndian(&best_iteration_, sizeof(best_iteration_), 1) != 1)
516-
return false;
517-
if (fp->FReadEndian(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1)
518-
return false;
519-
if (fp->FReadEndian(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1)
520-
return false;
521-
if (fp->FReadEndian(&worst_iteration_, sizeof(worst_iteration_), 1) != 1)
522-
return false;
523-
if (fp->FReadEndian(&stall_iteration_, sizeof(stall_iteration_), 1) != 1)
524-
return false;
490+
if (!fp->DeSerialize(&best_error_rate_)) return false;
491+
if (!fp->DeSerialize(&best_error_rates_[0], countof(best_error_rates_))) return false;
492+
if (!fp->DeSerialize(&best_iteration_)) return false;
493+
if (!fp->DeSerialize(&worst_error_rate_)) return false;
494+
if (!fp->DeSerialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false;
495+
if (!fp->DeSerialize(&worst_iteration_)) return false;
496+
if (!fp->DeSerialize(&stall_iteration_)) return false;
525497
if (!best_model_data_.DeSerialize(fp)) return false;
526498
if (!worst_model_data_.DeSerialize(fp)) return false;
527499
if (amount != NO_BEST_TRAINER && !best_trainer_.DeSerialize(fp)) return false;
@@ -536,9 +508,7 @@ bool LSTMTrainer::DeSerialize(const TessdataManager* mgr, TFile* fp) {
536508
}
537509
if (!best_error_history_.DeSerialize(fp)) return false;
538510
if (!best_error_iterations_.DeSerialize(fp)) return false;
539-
if (fp->FReadEndian(&improvement_steps_, sizeof(improvement_steps_), 1) != 1)
540-
return false;
541-
return true;
511+
return fp->DeSerialize(&improvement_steps_);
542512
}
543513

544514
// De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the

src/lstm/lstmtrainer.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class LSTMTrainer : public LSTMRecognizer {
147147
return best_iteration_;
148148
}
149149
int learning_iteration() const { return learning_iteration_; }
150-
int improvement_steps() const { return improvement_steps_; }
150+
int32_t improvement_steps() const { return improvement_steps_; }
151151
void set_perfect_delay(int delay) { perfect_delay_ = delay; }
152152
const GenericVector<char>& best_trainer() const { return best_trainer_; }
153153
// Returns the error that was just calculated by PrepareForBackward.
@@ -457,7 +457,7 @@ class LSTMTrainer : public LSTMRecognizer {
457457
GenericVector<double> best_error_history_;
458458
GenericVector<int> best_error_iterations_;
459459
// Number of iterations since the best_error_rate_ was 2% more than it is now.
460-
int improvement_steps_;
460+
int32_t improvement_steps_;
461461
// Number of iterations that yielded a non-zero delta error and thus provided
462462
// significant learning. learning_iteration_ <= training_iteration_.
463463
// learning_iteration_ is used to measure rate of learning progress.

0 commit comments

Comments
 (0)