Skip to content

Commit d36231e

Browse files
committed
Set best or user-selected IntSimdMatrix
Signed-off-by: Stefan Weil <sw@weilnetz.de>
1 parent 605b4d6 commit d36231e

File tree

7 files changed

+25
-50
lines changed

7 files changed

+25
-50
lines changed

src/arch/intsimdmatrix.cpp

+1-16
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,10 @@
2323

2424
namespace tesseract {
2525

26+
const IntSimdMatrix* IntSimdMatrix::intSimdMatrix = nullptr;
2627
const IntSimdMatrix IntSimdMatrix::IntSimdMatrixNative =
2728
IntSimdMatrix(1, 1, 1, 1, 1, {});
2829

29-
// Factory makes and returns an IntSimdMatrix (sub)class of the best
30-
// available type for the current architecture.
31-
/* static */
32-
const IntSimdMatrix* IntSimdMatrix::GetFastestMultiplier() {
33-
const IntSimdMatrix* multiplier;
34-
if (SIMDDetect::IsAVX2Available()) {
35-
multiplier = &IntSimdMatrixAVX2;
36-
} else if (SIMDDetect::IsSSEAvailable()) {
37-
multiplier = &IntSimdMatrixSSE;
38-
} else {
39-
// Default c++ implementation.
40-
multiplier = &IntSimdMatrixNative;
41-
}
42-
return multiplier;
43-
}
44-
4530
// Computes a reshaped copy of the weight matrix w. If there are no
4631
// partial_funcs_, it does nothing.
4732
void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w, std::vector<int8_t>& shaped_w) const {

src/arch/intsimdmatrix.h

+1-4
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,6 @@ class IntSimdMatrix {
8585
partial_funcs_(partial_funcs)
8686
{}
8787

88-
// Factory makes and returns an IntSimdMatrix (sub)class of the best
89-
// available type for the current architecture.
90-
static const IntSimdMatrix* GetFastestMultiplier();
91-
9288
// Computes a reshaped copy of the weight matrix w. If there are no
9389
// partial_funcs_, it does nothing.
9490
void Init(const GENERIC_2D_ARRAY<int8_t>& w, std::vector<int8_t>& shaped_w) const;
@@ -115,6 +111,7 @@ class IntSimdMatrix {
115111
const GenericVector<double>& scales, const int8_t* u,
116112
double* v) const;
117113

114+
static const IntSimdMatrix* intSimdMatrix;
118115
static const IntSimdMatrix IntSimdMatrixAVX2;
119116
static const IntSimdMatrix IntSimdMatrixSSE;
120117
static const IntSimdMatrix IntSimdMatrixNative;

src/arch/simddetect.cpp

+8-6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "dotproduct.h"
2020
#include "dotproductavx.h"
2121
#include "dotproductsse.h"
22+
#include "intsimdmatrix.h" // for IntSimdMatrix
2223
#include "params.h" // for STRING_VAR
2324
#include "tprintf.h" // for tprintf
2425

@@ -68,8 +69,9 @@ static double DotProductGeneric(const double* u, const double* v, int n) {
6869
return total;
6970
}
7071

71-
static void SetDotProduct(DotProductFunction function) {
72-
DotProduct = function;
72+
static void SetDotProduct(DotProductFunction f, const IntSimdMatrix* m = nullptr) {
73+
DotProduct = f;
74+
IntSimdMatrix::intSimdMatrix = m;
7375
}
7476

7577
// Constructor.
@@ -126,12 +128,12 @@ SIMDDetect::SIMDDetect() {
126128
#if defined(AVX)
127129
} else if (avx_available_) {
128130
// AVX detected.
129-
SetDotProduct(DotProductAVX);
131+
SetDotProduct(DotProductAVX, &IntSimdMatrix::IntSimdMatrixAVX2);
130132
#endif
131133
#if defined(SSE4_1)
132134
} else if (sse_available_) {
133135
// SSE detected.
134-
SetDotProduct(DotProductSSE);
136+
SetDotProduct(DotProductSSE, &IntSimdMatrix::IntSimdMatrixSSE);
135137
#endif
136138
}
137139
}
@@ -153,13 +155,13 @@ void SIMDDetect::Update() {
153155
#if defined(AVX)
154156
} else if (!strcmp(dotproduct.string(), "avx")) {
155157
// AVX selected by config variable.
156-
SetDotProduct(DotProductAVX);
158+
SetDotProduct(DotProductAVX, &IntSimdMatrix::IntSimdMatrixAVX2);
157159
dotproduct_method = "avx";
158160
#endif
159161
#if defined(SSE4_1)
160162
} else if (!strcmp(dotproduct.string(), "sse")) {
161163
// SSE selected by config variable.
162-
SetDotProduct(DotProductSSE);
164+
SetDotProduct(DotProductSSE, &IntSimdMatrix::IntSimdMatrixSSE);
163165
dotproduct_method = "sse";
164166
#endif
165167
} else {

src/lstm/networkio.cpp

+6-8
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,6 @@ const float kMinCertainty = -20.0f;
3131
// Probability corresponding to kMinCertainty.
3232
const float kMinProb = exp(kMinCertainty);
3333

34-
// Holds the optimal integer multiplier for this machine.
35-
// This is a leaked, lazily initialized singleton, and is used for computing
36-
// padding to apply to i_ for SIMD use.
37-
const IntSimdMatrix* NetworkIO::multiplier_ = nullptr;
38-
3934
// Resizes to a specific size as a 2-d temp buffer. No batches, no y-dim.
4035
void NetworkIO::Resize2d(bool int_mode, int width, int num_features) {
4136
stride_map_ = StrideMap();
@@ -985,9 +980,12 @@ void NetworkIO::ClipVector(int t, float range) {
985980
// for the SIMD operations to be safe.
986981
/* static */
987982
int NetworkIO::GetPadding(int num_features) {
988-
if (multiplier_ == nullptr)
989-
multiplier_ = IntSimdMatrix::GetFastestMultiplier();
990-
return multiplier_->RoundInputs(num_features) - num_features;
983+
int padding = 0;
984+
if (IntSimdMatrix::intSimdMatrix) {
985+
padding =
986+
IntSimdMatrix::intSimdMatrix->RoundInputs(num_features) - num_features;
987+
}
988+
return padding;
991989
}
992990

993991
} // namespace tesseract.

src/lstm/networkio.h

-4
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,6 @@ class NetworkIO {
338338
bool int_mode_;
339339
// Stride for 2d input data.
340340
StrideMap stride_map_;
341-
// Holds the optimal integer multiplier for this machine.
342-
// This is a leaked, lazily initialized singleton, and is used for computing
343-
// padding to apply to i_ for SIMD use.
344-
static const IntSimdMatrix* multiplier_;
345341
};
346342

347343
} // namespace tesseract.

src/lstm/weightmatrix.cpp

+5-6
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,8 @@ void WeightMatrix::ConvertToInt() {
143143
}
144144
wf_.Resize(1, 1, 0.0);
145145
int_mode_ = true;
146-
multiplier_ = IntSimdMatrix::GetFastestMultiplier();
147-
multiplier_->Init(wi_, shaped_w_);
146+
if (IntSimdMatrix::intSimdMatrix)
147+
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
148148
}
149149

150150
// Allocates any needed memory for running Backward, and zeroes the deltas,
@@ -196,8 +196,8 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
196196
if (int_mode_) {
197197
if (!wi_.DeSerialize(fp)) return false;
198198
if (!scales_.DeSerialize(fp)) return false;
199-
multiplier_ = IntSimdMatrix::GetFastestMultiplier();
200-
multiplier_->Init(wi_, shaped_w_);
199+
if (IntSimdMatrix::intSimdMatrix)
200+
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
201201
} else {
202202
if (!wf_.DeSerialize(fp)) return false;
203203
if (training) {
@@ -245,8 +245,7 @@ void WeightMatrix::MatrixDotVector(const double* u, double* v) const {
245245

246246
void WeightMatrix::MatrixDotVector(const int8_t* u, double* v) const {
247247
assert(int_mode_);
248-
assert(multiplier_ != nullptr);
249-
multiplier_->MatrixDotVector(wi_, shaped_w_, scales_, u, v);
248+
IntSimdMatrix::intSimdMatrix->MatrixDotVector(wi_, shaped_w_, scales_, u, v);
250249
}
251250

252251
// MatrixDotVector for peep weights, MultiplyAccumulate adds the

src/lstm/weightmatrix.h

+4-6
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class TransposedArray : public GENERIC_2D_ARRAY<double> {
6464
// backward steps with the matrix and updates to the weights.
6565
class WeightMatrix {
6666
public:
67-
WeightMatrix() : int_mode_(false), use_adam_(false), multiplier_(nullptr) {}
67+
WeightMatrix() : int_mode_(false), use_adam_(false) {}
6868
// Sets up the network for training. Initializes weights using weights of
6969
// scale `range` picked according to the random number generator `randomizer`.
7070
// Note the order is outputs, inputs, as this is the order of indices to
@@ -85,13 +85,13 @@ class WeightMatrix {
8585
// Scale so the max absolute value becomes INT8_MAX.
8686
// Round to integer.
8787
// Store a multiplicative scale factor (as a float) that will reproduce
88-
// the original value, subject to rounding errors.
88+
// the original value, subject to rounding errors.
8989
void ConvertToInt();
9090
// Returns the size rounded up to an internal factor used by the SIMD
9191
// implementation for its input.
9292
int RoundInputs(int size) const {
93-
if (multiplier_ == nullptr) return size;
94-
return multiplier_->RoundInputs(size);
93+
if (!int_mode_ || !IntSimdMatrix::intSimdMatrix) return size;
94+
return IntSimdMatrix::intSimdMatrix->RoundInputs(size);
9595
}
9696

9797
// Accessors.
@@ -178,8 +178,6 @@ class WeightMatrix {
178178
GENERIC_2D_ARRAY<double> dw_sq_sum_;
179179
// The weights matrix reorganized in whatever way suits this instance.
180180
std::vector<int8_t> shaped_w_;
181-
// Holds the optimal integer multiplier for this machine.
182-
const IntSimdMatrix* multiplier_;
183181
};
184182

185183
} // namespace tesseract.

0 commit comments

Comments
 (0)