@@ -60,14 +60,30 @@ namespace tesseract {
60
60
// is required to allow the base class implementation to do all the work.
61
61
class IntSimdMatrix {
62
62
public:
63
- // Constructor should set the data members to indicate the sizes.
64
- // NOTE: Base constructor public only for test purposes.
65
- IntSimdMatrix ()
66
- : num_outputs_per_register_(1 ),
67
- max_output_registers_ (1 ),
68
- num_inputs_per_register_(1 ),
69
- num_inputs_per_group_(1 ),
70
- num_input_groups_(1 ) {}
63
// Pointer to a function that evaluates one part of a matrix.vector product.
//   w       - int8 weights, laid out in the very specific order this class
//             requires (see above).
//   scales  - each integer result is scaled by the corresponding member.
//   u       - int8 input vector of length num_in.
//   num_in  - number of entries in u.
//   num_out - upper bound on the number of outputs written to v.
//   v       - receives the scaled outputs.
// How much of w and scales gets consumed is fixed, and the caller has no way
// of querying it.
using PartialFunc = void (*)(const int8_t* w, const double* scales,
                             const int8_t* u, int num_in, int num_out,
                             double* v);
72
+
73
+ IntSimdMatrix (int num_outputs_per_register, int max_output_registers, int num_inputs_per_register, int num_inputs_per_group, int num_input_groups, std::vector<PartialFunc> partial_funcs) :
74
+ // Number of 32 bit outputs held in each register.
75
+ num_outputs_per_register_ (num_outputs_per_register),
76
+ // Maximum number of registers that we will use to hold outputs.
77
+ max_output_registers_ (max_output_registers),
78
+ // Number of 8 bit inputs in the inputs register.
79
+ num_inputs_per_register_ (num_inputs_per_register),
80
+ // Number of inputs in each weight group.
81
+ num_inputs_per_group_ (num_inputs_per_group),
82
+ // Number of groups of inputs to be broadcast.
83
+ num_input_groups_ (num_input_groups),
84
+ // A series of functions to compute a partial result.
85
+ partial_funcs_ (partial_funcs)
86
+ {}
71
87
72
88
// Factory makes and returns an IntSimdMatrix (sub)class of the best
73
89
// available type for the current architecture.
@@ -100,16 +116,6 @@ class IntSimdMatrix {
100
116
double * v) const ;
101
117
102
118
protected:
103
- // Function to compute part of a matrix.vector multiplication. The weights
104
- // are in a very specific order (see above) in w, which is multiplied by
105
- // u of length num_in, to produce output v after scaling the integer results
106
- // by the corresponding member of scales.
107
- // The amount of w and scales consumed is fixed and not available to the
108
- // caller. The number of outputs written to v will be at most num_out.
109
- typedef void (*PartialFunc)(const int8_t * w, const double * scales,
110
- const int8_t * u, int num_in, int num_out,
111
- double * v);
112
-
113
119
// Rounds the input up to a multiple of the given factor.
114
120
static int Roundup (int input, int factor) {
115
121
return (input + factor - 1 ) / factor * factor;
0 commit comments