Skip to content

Commit 26be7c5

Browse files
committed
Use constructor with parameters for IntSimdMatrix
Signed-off-by: Stefan Weil <sw@weilnetz.de>
1 parent e237a38 commit 26be7c5

File tree

5 files changed

+37
-30
lines changed

5 files changed

+37
-30
lines changed

src/arch/intsimdmatrix.cpp

+1 -1
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ const IntSimdMatrix* IntSimdMatrix::GetFastestMultiplier() {
3636
multiplier = new IntSimdMatrixSSE();
3737
} else {
3838
// Default c++ implementation.
39-
multiplier = new IntSimdMatrix();
39+
multiplier = new IntSimdMatrix(1, 1, 1, 1, 1, {});
4040
}
4141
return multiplier;
4242
}

src/arch/intsimdmatrix.h

+24 -18
Original file line number | Diff line number | Diff line change
@@ -60,14 +60,30 @@ namespace tesseract {
6060
// is required to allow the base class implementation to do all the work.
6161
class IntSimdMatrix {
6262
public:
63-
// Constructor should set the data members to indicate the sizes.
64-
// NOTE: Base constructor public only for test purposes.
65-
IntSimdMatrix()
66-
: num_outputs_per_register_(1),
67-
max_output_registers_(1),
68-
num_inputs_per_register_(1),
69-
num_inputs_per_group_(1),
70-
num_input_groups_(1) {}
63+
// Function to compute part of a matrix.vector multiplication. The weights
64+
// are in a very specific order (see above) in w, which is multiplied by
65+
// u of length num_in, to produce output v after scaling the integer results
66+
// by the corresponding member of scales.
67+
// The amount of w and scales consumed is fixed and not available to the
68+
// caller. The number of outputs written to v will be at most num_out.
69+
typedef void (*PartialFunc)(const int8_t* w, const double* scales,
70+
const int8_t* u, int num_in, int num_out,
71+
double* v);
72+
73+
IntSimdMatrix(int num_outputs_per_register, int max_output_registers, int num_inputs_per_register, int num_inputs_per_group, int num_input_groups, std::vector<PartialFunc> partial_funcs) :
74+
// Number of 32 bit outputs held in each register.
75+
num_outputs_per_register_(num_outputs_per_register),
76+
// Maximum number of registers that we will use to hold outputs.
77+
max_output_registers_(max_output_registers),
78+
// Number of 8 bit inputs in the inputs register.
79+
num_inputs_per_register_(num_inputs_per_register),
80+
// Number of inputs in each weight group.
81+
num_inputs_per_group_(num_inputs_per_group),
82+
// Number of groups of inputs to be broadcast.
83+
num_input_groups_(num_input_groups),
84+
// A series of functions to compute a partial result.
85+
partial_funcs_(partial_funcs)
86+
{}
7187

7288
// Factory makes and returns an IntSimdMatrix (sub)class of the best
7389
// available type for the current architecture.
@@ -100,16 +116,6 @@ class IntSimdMatrix {
100116
double* v) const;
101117

102118
protected:
103-
// Function to compute part of a matrix.vector multiplication. The weights
104-
// are in a very specific order (see above) in w, which is multiplied by
105-
// u of length num_in, to produce output v after scaling the integer results
106-
// by the corresponding member of scales.
107-
// The amount of w and scales consumed is fixed and not available to the
108-
// caller. The number of outputs written to v will be at most num_out.
109-
typedef void (*PartialFunc)(const int8_t* w, const double* scales,
110-
const int8_t* u, int num_in, int num_out,
111-
double* v);
112-
113119
// Rounds the input up to a multiple of the given factor.
114120
static int Roundup(int input, int factor) {
115121
return (input + factor - 1) / factor * factor;

src/arch/intsimdmatrixavx2.cpp

+6 -8
Original file line number | Diff line number | Diff line change
@@ -269,16 +269,14 @@ static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
269269
namespace tesseract {
270270
#endif // __AVX2__
271271

272-
IntSimdMatrixAVX2::IntSimdMatrixAVX2() {
272+
IntSimdMatrixAVX2::IntSimdMatrixAVX2()
273273
#ifdef __AVX2__
274-
num_outputs_per_register_ = kNumOutputsPerRegister;
275-
max_output_registers_ = kMaxOutputRegisters;
276-
num_inputs_per_register_ = kNumInputsPerRegister;
277-
num_inputs_per_group_ = kNumInputsPerGroup;
278-
num_input_groups_ = kNumInputGroups;
279-
partial_funcs_ = {PartialMatrixDotVector64, PartialMatrixDotVector32,
280-
PartialMatrixDotVector16, PartialMatrixDotVector8};
274+
: IntSimdMatrix(kNumOutputsPerRegister, kMaxOutputRegisters, kNumInputsPerRegister, kNumInputsPerGroup, kNumInputGroups, {PartialMatrixDotVector64, PartialMatrixDotVector32,
275+
PartialMatrixDotVector16, PartialMatrixDotVector8})
276+
#else
277+
: IntSimdMatrix(1, 1, 1, 1, 1, {})
281278
#endif // __AVX2__
279+
{
282280
}
283281

284282
} // namespace tesseract.

src/arch/intsimdmatrixsse.cpp

+5 -2
Original file line number | Diff line number | Diff line change
@@ -33,10 +33,13 @@ static void PartialMatrixDotVector1(const int8_t* wi, const double* scales,
3333
}
3434
#endif // __SSE4_1__
3535

36-
IntSimdMatrixSSE::IntSimdMatrixSSE() {
36+
IntSimdMatrixSSE::IntSimdMatrixSSE()
3737
#ifdef __SSE4_1__
38-
partial_funcs_ = {PartialMatrixDotVector1};
38+
: IntSimdMatrix(1, 1, 1, 1, 1, {PartialMatrixDotVector1})
39+
#else
40+
: IntSimdMatrix(1, 1, 1, 1, 1, {})
3941
#endif // __SSE4_1__
42+
{
4043
}
4144

4245
} // namespace tesseract.

unittest/intsimdmatrix_test.cc

+1 -1
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ class IntSimdMatrixTest : public ::testing::Test {
8282
}
8383

8484
TRand random_;
85-
IntSimdMatrix base_;
85+
IntSimdMatrix base_ = IntSimdMatrix(1, 1, 1, 1, 1, {});
8686
};
8787

8888
// Test the C++ implementation without SIMD.

0 commit comments

Comments
 (0)