@@ -58,35 +58,8 @@ namespace tesseract {
58
58
// NOTE that, although the subclasses execute on different SIMD hardware, no
59
59
// virtual methods are needed, as the constructor sets up everything that
60
60
// is required to allow the base class implementation to do all the work.
61
- class IntSimdMatrix {
62
- public:
63
- // Function to compute part of a matrix.vector multiplication. The weights
64
- // are in a very specific order (see above) in w, which is multiplied by
65
- // u of length num_in, to produce output v after scaling the integer results
66
- // by the corresponding member of scales.
67
- // The amount of w and scales consumed is fixed and not available to the
68
- // caller. The number of outputs written to v will be at most num_out.
69
- typedef void (*PartialFunc)(const int8_t * w, const double * scales,
70
- const int8_t * u, int num_in, int num_out,
71
- double * v);
72
-
73
- IntSimdMatrix (int num_outputs_per_register, int max_output_registers, int num_inputs_per_register, int num_inputs_per_group, int num_input_groups, std::vector<PartialFunc> partial_funcs) :
74
- // Number of 32 bit outputs held in each register.
75
- num_outputs_per_register_ (num_outputs_per_register),
76
- // Maximum number of registers that we will use to hold outputs.
77
- max_output_registers_ (max_output_registers),
78
- // Number of 8 bit inputs in the inputs register.
79
- num_inputs_per_register_ (num_inputs_per_register),
80
- // Number of inputs in each weight group.
81
- num_inputs_per_group_ (num_inputs_per_group),
82
- // Number of groups of inputs to be broadcast.
83
- num_input_groups_ (num_input_groups),
84
- // A series of functions to compute a partial result.
85
- partial_funcs_ (partial_funcs)
86
- {}
87
-
88
- // Computes a reshaped copy of the weight matrix w. If there are no
89
- // partial_funcs_, it does nothing.
61
+ struct IntSimdMatrix {
62
+ // Computes a reshaped copy of the weight matrix w.
90
63
void Init (const GENERIC_2D_ARRAY<int8_t >& w, std::vector<int8_t >& shaped_w) const ;
91
64
92
65
// Rounds the size up to a multiple of the input register size (in int8_t).
@@ -102,20 +75,11 @@ class IntSimdMatrix {
102
75
// u is of size W.dim2() - 1 and the output v is of size W.dim1().
103
76
// u is imagined to have an extra element at the end with value 1, to
104
77
// implement the bias, but it doesn't actually have it.
105
- // Computes the base C++ implementation, if there are no partial_funcs_.
106
- // NOTE: The size of the input vector (u) must be padded using
107
- // RoundInputs above.
108
- // The input will be over-read to the extent of the padding. There are no
109
- // alignment requirements.
110
- void MatrixDotVector (const GENERIC_2D_ARRAY<int8_t >& w, const std::vector<int8_t >& shaped_w,
111
- const GenericVector<double >& scales, const int8_t * u,
112
- double * v) const ;
113
-
114
- static const IntSimdMatrix* intSimdMatrix;
115
- static const IntSimdMatrix IntSimdMatrixAVX2;
116
- static const IntSimdMatrix IntSimdMatrixSSE;
78
+ // Performs the computation using the base C++ implementation.
79
+ static void MatrixDotVector (const GENERIC_2D_ARRAY<int8_t >& w,
80
+ const GenericVector<double >& scales, const int8_t * u,
81
+ double * v);
117
82
118
- protected:
119
83
// Rounds the input up to a multiple of the given factor.
120
84
static int Roundup (int input, int factor) {
121
85
return (input + factor - 1 ) / factor * factor;
@@ -131,8 +95,23 @@ class IntSimdMatrix {
131
95
int num_inputs_per_group_;
132
96
// Number of groups of inputs to be broadcast.
133
97
int num_input_groups_;
134
- // A series of functions to compute a partial result.
135
- std::vector<PartialFunc> partial_funcs_;
98
+
99
+ // Computes matrix.vector v = Wu.
100
+ // u is of size W.dim2() - 1 and the output v is of size W.dim1().
101
+ // u is imagined to have an extra element at the end with value 1, to
102
+ // implement the bias, but it doesn't actually have it.
103
+ // Uses an optimized implementation with partial funcs.
104
+ // NOTE: The size of the input vector (u) must be padded using
105
+ // RoundInputs above.
106
+ // The input will be over-read to the extent of the padding. There are no
107
+ // alignment requirements.
108
+ typedef void (*MatrixDotVectorFunction)(int dim1, int dim2,
109
+ const int8_t * wi, const double * scales, const int8_t * u, double * v);
110
+ MatrixDotVectorFunction matrixDotVectorFunction;
111
+
112
+ static const IntSimdMatrix* intSimdMatrix;
113
+ static const IntSimdMatrix intSimdMatrixAVX2;
114
+ static const IntSimdMatrix intSimdMatrixSSE;
136
115
};
137
116
138
117
} // namespace tesseract
0 commit comments