Skip to content

Commit 9560639

Browse files
committed
Clean code for IntSimdMatrix
Signed-off-by: Stefan Weil <[email protected]>
1 parent 7fc7d28 commit 9560639

7 files changed

+155
-97
lines changed

src/arch/intsimdmatrix.cpp

+8-33
Original file line numberDiff line numberDiff line change
@@ -77,42 +77,17 @@ void IntSimdMatrix::Init(const GENERIC_2D_ARRAY<int8_t>& w,
7777
// u is imagined to have an extra element at the end with value 1, to
7878
// implement the bias, but it doesn't actually have it.
7979
void IntSimdMatrix::MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w,
80-
const std::vector<int8_t>& shaped_w,
8180
const GenericVector<double>& scales,
82-
const int8_t* u, double* v) const {
81+
const int8_t* u, double* v) {
8382
int num_out = w.dim1();
8483
int num_in = w.dim2() - 1;
85-
if (partial_funcs_.empty()) {
86-
// Base implementation.
87-
for (int i = 0; i < num_out; ++i) {
88-
const int8_t* wi = w[i];
89-
int total = 0;
90-
for (int j = 0; j < num_in; ++j) total += wi[j] * u[j];
91-
// Add in the bias and correct for integer values.
92-
v[i] = (static_cast<double>(total) / INT8_MAX + wi[num_in]) * scales[i];
93-
}
94-
} else {
95-
const int8_t* w_data = shaped_w.data();
96-
const double* scales_data = &scales[0];
97-
// Each call to a partial_func_ produces group_size outputs, except the
98-
// last one, which can produce less.
99-
int group_size = num_outputs_per_register_ * max_output_registers_;
100-
int rounded_num_in = Roundup(num_in, num_inputs_per_group_);
101-
int rounded_num_out = RoundOutputs(num_out);
102-
int output = 0;
103-
for (auto fn : partial_funcs_) {
104-
// The amount of w_data consumed by each call to fn.
105-
int w_step = (rounded_num_in + 1) * group_size;
106-
// Run with this group size, until it would produce too much output, then
107-
// switch to a smaller size.
108-
for (; output + group_size <= rounded_num_out; output += group_size) {
109-
(*fn)(w_data, scales_data, u, rounded_num_in, num_out - output, v);
110-
w_data += w_step;
111-
scales_data += group_size;
112-
v += group_size;
113-
}
114-
group_size /= 2;
115-
}
84+
// Base implementation.
85+
for (int i = 0; i < num_out; ++i) {
86+
const int8_t* wi = w[i];
87+
int total = 0;
88+
for (int j = 0; j < num_in; ++j) total += wi[j] * u[j];
89+
// Add in the bias and correct for integer values.
90+
v[i] = (static_cast<double>(total) / INT8_MAX + wi[num_in]) * scales[i];
11691
}
11792
}
11893

src/arch/intsimdmatrix.h

+23-44
Original file line numberDiff line numberDiff line change
@@ -58,35 +58,8 @@ namespace tesseract {
5858
// NOTE that, although the subclasses execute on different SIMD hardware, no
5959
// virtual methods are needed, as the constructor sets up everything that
6060
// is required to allow the base class implementation to do all the work.
61-
class IntSimdMatrix {
62-
public:
63-
// Function to compute part of a matrix.vector multiplication. The weights
64-
// are in a very specific order (see above) in w, which is multiplied by
65-
// u of length num_in, to produce output v after scaling the integer results
66-
// by the corresponding member of scales.
67-
// The amount of w and scales consumed is fixed and not available to the
68-
// caller. The number of outputs written to v will be at most num_out.
69-
typedef void (*PartialFunc)(const int8_t* w, const double* scales,
70-
const int8_t* u, int num_in, int num_out,
71-
double* v);
72-
73-
IntSimdMatrix(int num_outputs_per_register, int max_output_registers, int num_inputs_per_register, int num_inputs_per_group, int num_input_groups, std::vector<PartialFunc> partial_funcs) :
74-
// Number of 32 bit outputs held in each register.
75-
num_outputs_per_register_(num_outputs_per_register),
76-
// Maximum number of registers that we will use to hold outputs.
77-
max_output_registers_(max_output_registers),
78-
// Number of 8 bit inputs in the inputs register.
79-
num_inputs_per_register_(num_inputs_per_register),
80-
// Number of inputs in each weight group.
81-
num_inputs_per_group_(num_inputs_per_group),
82-
// Number of groups of inputs to be broadcast.
83-
num_input_groups_(num_input_groups),
84-
// A series of functions to compute a partial result.
85-
partial_funcs_(partial_funcs)
86-
{}
87-
88-
// Computes a reshaped copy of the weight matrix w. If there are no
89-
// partial_funcs_, it does nothing.
61+
struct IntSimdMatrix {
62+
// Computes a reshaped copy of the weight matrix w.
9063
void Init(const GENERIC_2D_ARRAY<int8_t>& w, std::vector<int8_t>& shaped_w) const;
9164

9265
// Rounds the size up to a multiple of the input register size (in int8_t).
@@ -102,20 +75,11 @@ class IntSimdMatrix {
10275
// u is of size W.dim2() - 1 and the output v is of size W.dim1().
10376
// u is imagined to have an extra element at the end with value 1, to
10477
// implement the bias, but it doesn't actually have it.
105-
// Computes the base C++ implementation, if there are no partial_funcs_.
106-
// NOTE: The size of the input vector (u) must be padded using
107-
// RoundInputs above.
108-
// The input will be over-read to the extent of the padding. There are no
109-
// alignment requirements.
110-
void MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w, const std::vector<int8_t>& shaped_w,
111-
const GenericVector<double>& scales, const int8_t* u,
112-
double* v) const;
113-
114-
static const IntSimdMatrix* intSimdMatrix;
115-
static const IntSimdMatrix IntSimdMatrixAVX2;
116-
static const IntSimdMatrix IntSimdMatrixSSE;
78+
// Computes the base C++ implementation.
79+
static void MatrixDotVector(const GENERIC_2D_ARRAY<int8_t>& w,
80+
const GenericVector<double>& scales, const int8_t* u,
81+
double* v);
11782

118-
protected:
11983
// Rounds the input up to a multiple of the given factor.
12084
static int Roundup(int input, int factor) {
12185
return (input + factor - 1) / factor * factor;
@@ -131,8 +95,23 @@ class IntSimdMatrix {
13195
int num_inputs_per_group_;
13296
// Number of groups of inputs to be broadcast.
13397
int num_input_groups_;
134-
// A series of functions to compute a partial result.
135-
std::vector<PartialFunc> partial_funcs_;
98+
99+
// Computes matrix.vector v = Wu.
100+
// u is of size W.dim2() - 1 and the output v is of size W.dim1().
101+
// u is imagined to have an extra element at the end with value 1, to
102+
// implement the bias, but it doesn't actually have it.
103+
// Uses an optimized implementation with partial funcs.
104+
// NOTE: The size of the input vector (u) must be padded using
105+
// RoundInputs above.
106+
// The input will be over-read to the extent of the padding. There are no
107+
// alignment requirements.
108+
typedef void (*MatrixDotVectorFunction)(int dim1, int dim2,
109+
const int8_t* wi, const double* scales, const int8_t* u, double* v);
110+
MatrixDotVectorFunction matrixDotVectorFunction;
111+
112+
static const IntSimdMatrix* intSimdMatrix;
113+
static const IntSimdMatrix intSimdMatrixAVX2;
114+
static const IntSimdMatrix intSimdMatrixSSE;
136115
};
137116

138117
} // namespace tesseract

src/arch/intsimdmatrixavx2.cpp

+73-3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ constexpr int kNumInputsPerGroup = 4;
4040
// Number of groups of inputs to be broadcast.
4141
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
4242

43+
// Functions to compute part of a matrix.vector multiplication. The weights
44+
// are in a very specific order (see above) in w, which is multiplied by
45+
// u of length num_in, to produce output v after scaling the integer results
46+
// by the corresponding member of scales.
47+
// The amount of w and scales consumed is fixed and not available to the
48+
// caller. The number of outputs written to v will be at most num_out.
49+
4350
// Computes one set of 4x8 products of inputs and weights, adding to result.
4451
// Horizontally adds 4 adjacent results, making 8x32-bit results.
4552
// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
@@ -269,8 +276,71 @@ static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
269276
ExtractResults(result0, shift_id, wi, scales, num_out, v);
270277
}
271278

272-
const IntSimdMatrix IntSimdMatrix::IntSimdMatrixAVX2 =
273-
IntSimdMatrix(kNumOutputsPerRegister, kMaxOutputRegisters, kNumInputsPerRegister, kNumInputsPerGroup, kNumInputGroups, {PartialMatrixDotVector64, PartialMatrixDotVector32,
274-
PartialMatrixDotVector16, PartialMatrixDotVector8});
279+
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
280+
const double* scales, const int8_t* u, double* v) {
281+
const int num_out = dim1;
282+
const int num_in = dim2 - 1;
283+
// Each call to a partial_func_ produces group_size outputs, except the
284+
// last one, which can produce less.
285+
const int rounded_num_in =
286+
IntSimdMatrix::Roundup(num_in, kNumInputsPerGroup);
287+
const int rounded_num_out =
288+
IntSimdMatrix::Roundup(num_out, kNumOutputsPerRegister);
289+
int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
290+
int output = 0;
291+
292+
int w_step = (rounded_num_in + 1) * group_size;
293+
294+
// Run with this group size, until it would produce too much output, then
295+
// switch to a smaller size.
296+
for (; output + group_size <= rounded_num_out; output += group_size) {
297+
PartialMatrixDotVector64(wi, scales, u, rounded_num_in, num_out - output, v);
298+
wi += w_step;
299+
scales += group_size;
300+
v += group_size;
301+
}
302+
group_size /= 2;
303+
w_step /= 2;
304+
305+
for (; output + group_size <= rounded_num_out; output += group_size) {
306+
PartialMatrixDotVector32(wi, scales, u, rounded_num_in, num_out - output, v);
307+
wi += w_step;
308+
scales += group_size;
309+
v += group_size;
310+
}
311+
group_size /= 2;
312+
w_step /= 2;
313+
314+
for (; output + group_size <= rounded_num_out; output += group_size) {
315+
PartialMatrixDotVector16(wi, scales, u, rounded_num_in, num_out - output, v);
316+
wi += w_step;
317+
scales += group_size;
318+
v += group_size;
319+
}
320+
group_size /= 2;
321+
w_step /= 2;
322+
323+
for (; output + group_size <= rounded_num_out; output += group_size) {
324+
PartialMatrixDotVector8(wi, scales, u, rounded_num_in, num_out - output, v);
325+
wi += w_step;
326+
scales += group_size;
327+
v += group_size;
328+
}
329+
}
330+
331+
const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
332+
// Number of 32 bit outputs held in each register.
333+
kNumOutputsPerRegister,
334+
// Maximum number of registers that we will use to hold outputs.
335+
kMaxOutputRegisters,
336+
// Number of 8 bit inputs in the inputs register.
337+
kNumInputsPerRegister,
338+
// Number of inputs in each weight group.
339+
kNumInputsPerGroup,
340+
// Number of groups of inputs to be broadcast.
341+
kNumInputGroups,
342+
// Function.
343+
matrixDotVector
344+
};
275345

276346
} // namespace tesseract.

src/arch/intsimdmatrixsse.cpp

+27-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,32 @@ static void PartialMatrixDotVector1(const int8_t* wi, const double* scales,
3535
*v = (total / INT8_MAX + wi[num_in]) * *scales;
3636
}
3737

38-
const IntSimdMatrix IntSimdMatrix::IntSimdMatrixSSE =
39-
IntSimdMatrix(1, 1, 1, 1, 1, {PartialMatrixDotVector1});
38+
static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
39+
const double* scales, const int8_t* u, double* v) {
40+
const int num_out = dim1;
41+
const int num_in = dim2 - 1;
42+
int output = 0;
43+
44+
for (; output + 1 <= num_out; output += 1) {
45+
PartialMatrixDotVector1(wi, scales, u, num_in, num_out - output, v);
46+
wi += dim2;
47+
scales += 1;
48+
v += 1;
49+
}
50+
}
51+
52+
const IntSimdMatrix IntSimdMatrix::intSimdMatrixSSE = {
53+
// Number of 32 bit outputs held in each register.
54+
1,
55+
// Maximum number of registers that we will use to hold outputs.
56+
1,
57+
// Number of 8 bit inputs in the inputs register.
58+
1,
59+
// Number of inputs in each weight group.
60+
1,
61+
// Number of groups of inputs to be broadcast.
62+
1,
63+
matrixDotVector
64+
};
4065

4166
} // namespace tesseract.

src/arch/simddetect.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,12 @@ SIMDDetect::SIMDDetect() {
128128
#if defined(AVX)
129129
} else if (avx_available_) {
130130
// AVX detected.
131-
SetDotProduct(DotProductAVX, &IntSimdMatrix::IntSimdMatrixAVX2);
131+
SetDotProduct(DotProductAVX, &IntSimdMatrix::intSimdMatrixAVX2);
132132
#endif
133133
#if defined(SSE4_1)
134134
} else if (sse_available_) {
135135
// SSE detected.
136-
SetDotProduct(DotProductSSE, &IntSimdMatrix::IntSimdMatrixSSE);
136+
SetDotProduct(DotProductSSE, &IntSimdMatrix::intSimdMatrixSSE);
137137
#endif
138138
}
139139
}
@@ -155,13 +155,13 @@ void SIMDDetect::Update() {
155155
#if defined(AVX)
156156
} else if (!strcmp(dotproduct.string(), "avx")) {
157157
// AVX selected by config variable.
158-
SetDotProduct(DotProductAVX, &IntSimdMatrix::IntSimdMatrixAVX2);
158+
SetDotProduct(DotProductAVX, &IntSimdMatrix::intSimdMatrixAVX2);
159159
dotproduct_method = "avx";
160160
#endif
161161
#if defined(SSE4_1)
162162
} else if (!strcmp(dotproduct.string(), "sse")) {
163163
// SSE selected by config variable.
164-
SetDotProduct(DotProductSSE, &IntSimdMatrix::IntSimdMatrixSSE);
164+
SetDotProduct(DotProductSSE, &IntSimdMatrix::intSimdMatrixSSE);
165165
dotproduct_method = "sse";
166166
#endif
167167
} else {

src/lstm/weightmatrix.cpp

+10-3
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,9 @@ void WeightMatrix::ConvertToInt() {
143143
}
144144
wf_.Resize(1, 1, 0.0);
145145
int_mode_ = true;
146-
if (IntSimdMatrix::intSimdMatrix)
146+
if (IntSimdMatrix::intSimdMatrix) {
147147
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
148+
}
148149
}
149150

150151
// Allocates any needed memory for running Backward, and zeroes the deltas,
@@ -196,8 +197,9 @@ bool WeightMatrix::DeSerialize(bool training, TFile* fp) {
196197
if (int_mode_) {
197198
if (!wi_.DeSerialize(fp)) return false;
198199
if (!scales_.DeSerialize(fp)) return false;
199-
if (IntSimdMatrix::intSimdMatrix)
200+
if (IntSimdMatrix::intSimdMatrix) {
200201
IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
202+
}
201203
} else {
202204
if (!wf_.DeSerialize(fp)) return false;
203205
if (training) {
@@ -245,7 +247,12 @@ void WeightMatrix::MatrixDotVector(const double* u, double* v) const {
245247

246248
void WeightMatrix::MatrixDotVector(const int8_t* u, double* v) const {
247249
assert(int_mode_);
248-
IntSimdMatrix::intSimdMatrix->MatrixDotVector(wi_, shaped_w_, scales_, u, v);
250+
if (IntSimdMatrix::intSimdMatrix) {
251+
IntSimdMatrix::intSimdMatrix->matrixDotVectorFunction(
252+
wi_.dim1(), wi_.dim2(), &shaped_w_[0], &scales_[0], u, v);
253+
} else {
254+
IntSimdMatrix::MatrixDotVector(wi_, scales_, u, v);
255+
}
249256
}
250257

251258
// MatrixDotVector for peep weights, MultiplyAccumulate adds the

unittest/intsimdmatrix_test.cc

+10-8
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
namespace tesseract {
2626
namespace {
2727

28-
static const IntSimdMatrix IntSimdMatrixNative = IntSimdMatrix(1, 1, 1, 1, 1, {});
29-
3028
class IntSimdMatrixTest : public ::testing::Test {
3129
protected:
3230
// Makes a random weights matrix of the given size.
@@ -65,12 +63,16 @@ class IntSimdMatrixTest : public ::testing::Test {
6563
std::vector<int8_t> u = RandomVector(num_in, matrix);
6664
GenericVector<double> scales = RandomScales(num_out);
6765
std::vector<double> base_result(num_out);
68-
std::vector<int8_t> dummy;
69-
IntSimdMatrixNative.MatrixDotVector(w, dummy, scales, u.data(), base_result.data());
66+
IntSimdMatrix::MatrixDotVector(w, scales, u.data(), base_result.data());
7067
std::vector<double> test_result(num_out);
7168
std::vector<int8_t> shaped_wi;
7269
matrix.Init(w, shaped_wi);
73-
matrix.MatrixDotVector(w, shaped_wi, scales, u.data(), test_result.data());
70+
if (matrix.matrixDotVectorFunction) {
71+
matrix.matrixDotVectorFunction(w.dim1(), w.dim2(), &shaped_wi[0],
72+
&scales[0], &u[0], &test_result[0]);
73+
} else {
74+
IntSimdMatrix::MatrixDotVector(w, scales, u.data(), test_result.data());
75+
}
7476
for (int i = 0; i < num_out; ++i) {
7577
EXPECT_FLOAT_EQ(base_result[i], test_result[i]) << "i=" << i;
7678
total += base_result[i];
@@ -86,7 +88,7 @@ class IntSimdMatrixTest : public ::testing::Test {
8688

8789
// Test the C++ implementation without SIMD.
8890
TEST_F(IntSimdMatrixTest, C) {
89-
static const IntSimdMatrix matrix(1, 1, 1, 1, 1, {});
91+
static const IntSimdMatrix matrix = {1, 1, 1, 1, 1, nullptr};
9092
ExpectEqualResults(matrix);
9193
}
9294

@@ -99,7 +101,7 @@ TEST_F(IntSimdMatrixTest, SSE) {
99101
tprintf("No SSE found! Not tested!");
100102
return;
101103
}
102-
ExpectEqualResults(IntSimdMatrix::IntSimdMatrixSSE);
104+
ExpectEqualResults(IntSimdMatrix::intSimdMatrixSSE);
103105
#else
104106
tprintf("SSE unsupported! Not tested!");
105107
#endif
@@ -114,7 +116,7 @@ TEST_F(IntSimdMatrixTest, AVX2) {
114116
tprintf("No AVX2 found! Not tested!");
115117
return;
116118
}
117-
ExpectEqualResults(IntSimdMatrix::IntSimdMatrixAVX2);
119+
ExpectEqualResults(IntSimdMatrix::intSimdMatrixAVX2);
118120
#else
119121
tprintf("AVX2 unsupported! Not tested!");
120122
#endif

0 commit comments

Comments
 (0)