Skip to content

Commit b86b4fa

Browse files
committed
Better fix for re-enabling training
1 parent 0afd593 commit b86b4fa

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

lstm/fullyconnected.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ void FullyConnected::SetEnableTraining(TrainingState state) {
6565
// Temp disable only from enabled.
6666
if (training_ == TS_ENABLED) training_ = state;
6767
} else {
68-
if (state == TS_ENABLED && training_ == TS_DISABLED)
68+
if (state == TS_ENABLED && training_ != TS_ENABLED)
6969
weights_.InitBackward();
7070
training_ = state;
7171
}

lstm/lstm.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ void LSTM::SetEnableTraining(TrainingState state) {
113113
// Temp disable only from enabled.
114114
if (training_ == TS_ENABLED) training_ = state;
115115
} else {
116-
if (state == TS_ENABLED && training_ == TS_DISABLED) {
116+
if (state == TS_ENABLED && training_ != TS_ENABLED) {
117117
for (int w = 0; w < WT_COUNT; ++w) {
118118
if (w == GFS && !Is2D()) continue;
119119
gate_weights_[w].InitBackward();

0 commit comments

Comments
 (0)