@@ -87,7 +87,7 @@ void RecodeBeamSearch::Decode(const NetworkIO& output, double dict_ratio,
87
87
if (lstm_choice_mode)
88
88
timesteps.clear ();
89
89
for (int t = 0 ; t < width; ++t) {
90
- ComputeTopN (output.f (t), output.NumFeatures (), kBeamWidths [0 ], charset );
90
+ ComputeTopN (output.f (t), output.NumFeatures (), kBeamWidths [0 ]);
91
91
DecodeStep (output.f (t), t, dict_ratio, cert_offset, worst_dict_cert,
92
92
charset);
93
93
if (lstm_choice_mode) {
@@ -102,7 +102,7 @@ void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY<float>& output,
102
102
beam_size_ = 0 ;
103
103
int width = output.dim1 ();
104
104
for (int t = 0 ; t < width; ++t) {
105
- ComputeTopN (output[t], output.dim2 (), kBeamWidths [0 ], charset );
105
+ ComputeTopN (output[t], output.dim2 (), kBeamWidths [0 ]);
106
106
DecodeStep (output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
107
107
}
108
108
}
@@ -456,19 +456,12 @@ WERD_RES* RecodeBeamSearch::InitializeWord(bool leading_space,
456
456
// Fills top_n_flags_ with bools that are true iff the corresponding output
457
457
// is one of the top_n.
458
458
void RecodeBeamSearch::ComputeTopN (const float * outputs, int num_outputs,
459
- int top_n, const UNICHARSET* charset ) {
459
+ int top_n) {
460
460
top_n_flags_.init_to_size (num_outputs, TN_ALSO_RAN);
461
461
top_code_ = -1 ;
462
462
second_code_ = -1 ;
463
463
top_heap_.clear ();
464
464
for (int i = 0 ; i < num_outputs; ++i) {
465
- // Decode label via recoder_.
466
- RecodedCharID code;
467
- code.Set (0 , i);
468
- int label = recoder_.DecodeUnichar (code);
469
- if (label != INVALID_UNICHAR_ID && // not part of a bigger code.
470
- !charset->get_enabled (label)) // disabled
471
- continue ;
472
465
if (top_heap_.size () < top_n || outputs[i] > top_heap_.PeekTop ().key ) {
473
466
TopPair entry (outputs[i], i);
474
467
top_heap_.Push (&entry);
@@ -505,10 +498,10 @@ void RecodeBeamSearch::DecodeStep(const float* outputs, int t,
505
498
if (t == 0 ) {
506
499
// The first step can only use singles and initials.
507
500
ContinueContext (nullptr , BeamIndex (false , NC_ANYTHING, 0 ), outputs, TN_TOP2,
508
- dict_ratio, cert_offset, worst_dict_cert, step);
501
+ charset, dict_ratio, cert_offset, worst_dict_cert, step);
509
502
if (dict_ != nullptr ) {
510
- ContinueContext (nullptr , BeamIndex (true , NC_ANYTHING, 0 ), outputs,
511
- TN_TOP2 , dict_ratio, cert_offset, worst_dict_cert, step);
503
+ ContinueContext (nullptr , BeamIndex (true , NC_ANYTHING, 0 ), outputs, TN_TOP2,
504
+ charset , dict_ratio, cert_offset, worst_dict_cert, step);
512
505
}
513
506
} else {
514
507
RecodeBeam* prev = beam_[t - 1 ];
@@ -540,9 +533,8 @@ void RecodeBeamSearch::DecodeStep(const float* outputs, int t,
540
533
// best first, but it comes before a lot of the worst, so it is slightly
541
534
// more efficient than going forwards.
542
535
for (int i = prev->beams_ [index ].size () - 1 ; i >= 0 ; --i) {
543
- ContinueContext (&prev->beams_ [index ].get (i).data , index , outputs,
544
- top_n, dict_ratio, cert_offset, worst_dict_cert,
545
- step);
536
+ ContinueContext (&prev->beams_ [index ].get (i).data , index , outputs, top_n,
537
+ charset, dict_ratio, cert_offset, worst_dict_cert, step);
546
538
}
547
539
}
548
540
for (int index = 0 ; index < kNumBeams ; ++index ) {
@@ -569,7 +561,9 @@ void RecodeBeamSearch::DecodeStep(const float* outputs, int t,
569
561
// choices for which top_n_flags[index] == top_n_flag.
570
562
void RecodeBeamSearch::ContinueContext (const RecodeNode* prev, int index,
571
563
const float * outputs,
572
- TopNState top_n_flag, double dict_ratio,
564
+ TopNState top_n_flag,
565
+ const UNICHARSET* charset,
566
+ double dict_ratio,
573
567
double cert_offset,
574
568
double worst_dict_cert,
575
569
RecodeBeam* step) {
@@ -632,6 +626,8 @@ void RecodeBeamSearch::ContinueContext(const RecodeNode* prev, int index,
632
626
int unichar_id = recoder_.DecodeUnichar (full_code);
633
627
// Map the null char to INVALID.
634
628
if (length == 0 && code == null_char_) unichar_id = INVALID_UNICHAR_ID;
629
+ if (unichar_id != INVALID_UNICHAR_ID && !charset->get_enabled (unichar_id))
630
+ continue ; // disabled by whitelist/blacklist
635
631
ContinueUnichar (code, unichar_id, cert, worst_dict_cert, dict_ratio,
636
632
use_dawgs, NC_ANYTHING, prev, step);
637
633
if (top_n_flag == TN_TOP2 && code != null_char_) {
0 commit comments