Skip to content

Commit b459990

Browse files
committed
LSTM char_whitelist/blacklist (6ac2ff0): multi-code chars
- Move the decision from ComputeTopN to ContinueContext, where it belongs: block context continuations which emit final codes translating to disabled unichar_ids. (The normal fallback logic from top2 > top-N > rest will still apply.)
- Pass UNICHARSET refs appropriately.
1 parent 8012d5e commit b459990

File tree

2 files changed

+17
-21
lines changed

2 files changed

+17
-21
lines changed

src/lstm/recodebeam.cpp

+13-17
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ void RecodeBeamSearch::Decode(const NetworkIO& output, double dict_ratio,
8787
if (lstm_choice_mode)
8888
timesteps.clear();
8989
for (int t = 0; t < width; ++t) {
90-
ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0], charset);
90+
ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]);
9191
DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert,
9292
charset);
9393
if (lstm_choice_mode) {
@@ -102,7 +102,7 @@ void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float>& output,
102102
beam_size_ = 0;
103103
int width = output.dim1();
104104
for (int t = 0; t < width; ++t) {
105-
ComputeTopN(output[t], output.dim2(), kBeamWidths[0], charset);
105+
ComputeTopN(output[t], output.dim2(), kBeamWidths[0]);
106106
DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
107107
}
108108
}
@@ -456,19 +456,12 @@ WERD_RES* RecodeBeamSearch::InitializeWord(bool leading_space,
456456
// Fills top_n_flags_ with bools that are true iff the corresponding output
457457
// is one of the top_n.
458458
void RecodeBeamSearch::ComputeTopN(const float* outputs, int num_outputs,
459-
int top_n, const UNICHARSET* charset) {
459+
int top_n) {
460460
top_n_flags_.init_to_size(num_outputs, TN_ALSO_RAN);
461461
top_code_ = -1;
462462
second_code_ = -1;
463463
top_heap_.clear();
464464
for (int i = 0; i < num_outputs; ++i) {
465-
// Decode label via recoder_.
466-
RecodedCharID code;
467-
code.Set(0, i);
468-
int label = recoder_.DecodeUnichar(code);
469-
if (label != INVALID_UNICHAR_ID && // not part of a bigger code.
470-
!charset->get_enabled(label)) // disabled
471-
continue;
472465
if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key) {
473466
TopPair entry(outputs[i], i);
474467
top_heap_.Push(&entry);
@@ -505,10 +498,10 @@ void RecodeBeamSearch::DecodeStep(const float* outputs, int t,
505498
if (t == 0) {
506499
// The first step can only use singles and initials.
507500
ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2,
508-
dict_ratio, cert_offset, worst_dict_cert, step);
501+
charset, dict_ratio, cert_offset, worst_dict_cert, step);
509502
if (dict_ != nullptr) {
510-
ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs,
511-
TN_TOP2, dict_ratio, cert_offset, worst_dict_cert, step);
503+
ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs, TN_TOP2,
504+
charset, dict_ratio, cert_offset, worst_dict_cert, step);
512505
}
513506
} else {
514507
RecodeBeam* prev = beam_[t - 1];
@@ -540,9 +533,8 @@ void RecodeBeamSearch::DecodeStep(const float* outputs, int t,
540533
// best first, but it comes before a lot of the worst, so it is slightly
541534
// more efficient than going forwards.
542535
for (int i = prev->beams_[index].size() - 1; i >= 0; --i) {
543-
ContinueContext(&prev->beams_[index].get(i).data, index, outputs,
544-
top_n, dict_ratio, cert_offset, worst_dict_cert,
545-
step);
536+
ContinueContext(&prev->beams_[index].get(i).data, index, outputs, top_n,
537+
charset, dict_ratio, cert_offset, worst_dict_cert, step);
546538
}
547539
}
548540
for (int index = 0; index < kNumBeams; ++index) {
@@ -569,7 +561,9 @@ void RecodeBeamSearch::DecodeStep(const float* outputs, int t,
569561
// choices for which top_n_flags[index] == top_n_flag.
570562
void RecodeBeamSearch::ContinueContext(const RecodeNode* prev, int index,
571563
const float* outputs,
572-
TopNState top_n_flag, double dict_ratio,
564+
TopNState top_n_flag,
565+
const UNICHARSET* charset,
566+
double dict_ratio,
573567
double cert_offset,
574568
double worst_dict_cert,
575569
RecodeBeam* step) {
@@ -632,6 +626,8 @@ void RecodeBeamSearch::ContinueContext(const RecodeNode* prev, int index,
632626
int unichar_id = recoder_.DecodeUnichar(full_code);
633627
// Map the null char to INVALID.
634628
if (length == 0 && code == null_char_) unichar_id = INVALID_UNICHAR_ID;
629+
if (unichar_id != INVALID_UNICHAR_ID && !charset->get_enabled(unichar_id))
630+
continue; // disabled by whitelist/blacklist
635631
ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
636632
use_dawgs, NC_ANYTHING, prev, step);
637633
if (top_n_flag == TN_TOP2 && code != null_char_) {

src/lstm/recodebeam.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ class RecodeBeamSearch {
293293

294294
// Fills top_n_flags_ with bools that are true iff the corresponding output
295295
// is one of the top_n.
296-
void ComputeTopN(const float* outputs, int num_outputs, int top_n, const UNICHARSET* unicharset);
296+
void ComputeTopN(const float* outputs, int num_outputs, int top_n);
297297

298298
// Adds the computation for the current time-step to the beam. Call at each
299299
// time-step in sequence from left to right. outputs is the activation vector
@@ -310,9 +310,9 @@ class RecodeBeamSearch {
310310
// using the given network outputs to provide scores to the choices. Uses only
311311
// those choices for which top_n_flags[code] == top_n_flag.
312312
void ContinueContext(const RecodeNode* prev, int index, const float* outputs,
313-
TopNState top_n_flag, double dict_ratio,
314-
double cert_offset, double worst_dict_cert,
315-
RecodeBeam* step);
313+
TopNState top_n_flag, const UNICHARSET* unicharset,
314+
double dict_ratio, double cert_offset,
315+
double worst_dict_cert, RecodeBeam* step);
316316
// Continues for a new unichar, using dawg or non-dawg as per flag.
317317
void ContinueUnichar(int code, int unichar_id, float cert,
318318
float worst_dict_cert, float dict_ratio, bool use_dawgs,

0 commit comments

Comments
 (0)