Skip to content

Commit f0a4d04

Browse files
committed
Add config variable for selection of dot product function
All also a C++ implementation with more aggressive compiler options which is optimized for the CPU where the software was built. It is now possible to select the function used for the dot product with -c dotproduct=FUNCTION where FUNCTION can be one of those values: * auto selection based on detected hardware (default) * generic C++ code with default compiler options * native C++ code optimized for build host * avx optimized code for AVX * sse optimized code for SSE Signed-off-by: Stefan Weil <[email protected]>
1 parent b527b37 commit f0a4d04

9 files changed

+160
-32
lines changed

src/api/Makefile.am

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ libtesseract_la_LIBADD = \
4848
../classify/libtesseract_classify.la \
4949
../dict/libtesseract_dict.la \
5050
../arch/libtesseract_arch.la \
51+
../arch/libtesseract_native.la \
5152
../arch/libtesseract_avx.la \
5253
../arch/libtesseract_avx2.la \
5354
../arch/libtesseract_sse.la \

src/api/tesseractmain.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
/**********************************************************************
2-
* File: tesseractmain.cpp (Formerly tessedit.c)
2+
* File: tesseractmain.cpp
33
* Description: Main program for merge of tess and editor.
44
* Author: Ray Smith
5-
* Created: Tue Jan 07 15:21:46 GMT 1992
65
*
76
* (C) Copyright 1992, Hewlett-Packard Ltd.
87
** Licensed under the Apache License, Version 2.0 (the "License");
@@ -585,6 +584,9 @@ int main(int argc, char** argv) {
585584

586585
SetVariablesFromCLArgs(&api, argc, argv);
587586

587+
// SIMD settings might be overridden by config variable.
588+
tesseract::SIMDDetect::Update();
589+
588590
if (list_langs) {
589591
PrintLangsList(&api);
590592
return EXIT_SUCCESS;

src/arch/Makefile.am

+7-2
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ endif
1010

1111
pkginclude_HEADERS =
1212

13-
noinst_HEADERS = dotproductavx.h dotproductsse.h
13+
noinst_HEADERS = dotproduct.h dotproductavx.h dotproductsse.h
1414
noinst_HEADERS += intsimdmatrix.h intsimdmatrixavx2.h intsimdmatrixsse.h
1515
noinst_HEADERS += simddetect.h
1616

17-
noinst_LTLIBRARIES = libtesseract_avx.la libtesseract_avx2.la libtesseract_sse.la
17+
noinst_LTLIBRARIES = libtesseract_native.la
18+
noinst_LTLIBRARIES += libtesseract_avx.la libtesseract_avx2.la
19+
noinst_LTLIBRARIES += libtesseract_sse.la
1820
noinst_LTLIBRARIES += libtesseract_arch.la
1921

2022
if AVX_OPT
@@ -27,6 +29,9 @@ if SSE41_OPT
2729
libtesseract_sse_la_CXXFLAGS = -ffast-math -msse4.1
2830
endif
2931

32+
libtesseract_native_la_CXXFLAGS = -O3 -ffast-math -march=native -mtune=native
33+
libtesseract_native_la_SOURCES = dotproduct.cpp
34+
3035
libtesseract_arch_la_SOURCES = intsimdmatrix.cpp simddetect.cpp
3136

3237
libtesseract_avx_la_SOURCES = dotproductavx.cpp

src/arch/dotproduct.cpp

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
///////////////////////////////////////////////////////////////////////
2+
// File: dotproduct.h
3+
// Description: Native dot product function.
4+
//
5+
// (C) Copyright 2018, Google Inc.
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
///////////////////////////////////////////////////////////////////////
16+
17+
#include "dotproduct.h"
18+
19+
namespace tesseract {
20+
21+
// Computes and returns the dot product of the two n-vectors u and v.
22+
double DotProductNative(const double* u, const double* v, int n) {
23+
double total = 0.0;
24+
for (int k = 0; k < n; ++k) total += u[k] * v[k];
25+
return total;
26+
}
27+
28+
} // namespace tesseract

src/arch/dotproduct.h

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
///////////////////////////////////////////////////////////////////////
2+
// File: dotproduct.h
3+
// Description: Native dot product function.
4+
//
5+
// (C) Copyright 2018, Google Inc.
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
///////////////////////////////////////////////////////////////////////
16+
17+
#ifndef TESSERACT_ARCH_DOTPRODUCT_H_
18+
#define TESSERACT_ARCH_DOTPRODUCT_H_
19+
20+
namespace tesseract {
21+
22+
// Computes and returns the dot product of the n-vectors u and v.
23+
double DotProductNative(const double* u, const double* v, int n);
24+
25+
} // namespace tesseract.
26+
27+
#endif // TESSERACT_ARCH_DOTPRODUCT_H_

src/arch/simddetect.cpp

+85
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
///////////////////////////////////////////////////////////////////////
1717

1818
#include "simddetect.h"
19+
#include "dotproduct.h"
20+
#include "dotproductavx.h"
21+
#include "dotproductsse.h"
22+
#include "params.h" // for STRING_VAR
23+
#include "tprintf.h" // for tprintf
1924

2025
#undef X86_BUILD
2126
#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32)
@@ -34,6 +39,21 @@
3439

3540
namespace tesseract {
3641

42+
// Computes and returns the dot product of the two n-vectors u and v.
43+
// Note: because the order of addition is different among the different dot
44+
// product functions, the results can (and do) vary slightly (although they
45+
// agree to within about 4e-15). This produces different results when running
46+
// training, despite all random inputs being precisely equal.
47+
// To get consistent results, use just one of these dot product functions.
48+
// On a test multi-layer network, serial is 57% slower than SSE, and AVX
49+
// is about 8% faster than SSE. This suggests that the time is memory
50+
// bandwidth constrained and could benefit from holding the reused vector
51+
// in AVX registers.
52+
DotProductFunction DotProduct;
53+
54+
static STRING_VAR(dotproduct, "auto",
55+
"Function used for calculation of dot product");
56+
3757
SIMDDetect SIMDDetect::detector;
3858

3959
// If true, then AVX has been detected.
@@ -44,12 +64,26 @@ bool SIMDDetect::avx512BW_available_;
4464
// If true, then SSe4.1 has been detected.
4565
bool SIMDDetect::sse_available_;
4666

67+
// Computes and returns the dot product of the two n-vectors u and v.
68+
static double DotProductGeneric(const double* u, const double* v, int n) {
69+
double total = 0.0;
70+
for (int k = 0; k < n; ++k) total += u[k] * v[k];
71+
return total;
72+
}
73+
74+
static void SetDotProduct(DotProductFunction function) {
75+
DotProduct = function;
76+
}
77+
4778
// Constructor.
4879
// Tests the architecture in a system-dependent way to detect AVX, SSE and
4980
// any other available SIMD equipment.
5081
// __GNUC__ is also defined by compilers that include GNU extensions such as
5182
// clang.
5283
SIMDDetect::SIMDDetect() {
84+
// The fallback is a generic dot product calculation.
85+
SetDotProduct(DotProductGeneric);
86+
5387
#if defined(X86_BUILD)
5488
# if defined(__GNUC__)
5589
unsigned int eax, ebx, ecx, edx;
@@ -80,6 +114,57 @@ SIMDDetect::SIMDDetect() {
80114
# error "I don't know how to test for SIMD with this compiler"
81115
# endif
82116
#endif // X86_BUILD
117+
118+
#if defined(X86_BUILD)
119+
// Select code for calculation of dot product based on autodetection.
120+
if (avx_available_) {
121+
// AVX detected.
122+
SetDotProduct(DotProductAVX);
123+
} else if (sse_available_) {
124+
// SSE detected.
125+
SetDotProduct(DotProductSSE);
126+
}
127+
#endif // X86_BUILD
128+
}
129+
130+
void SIMDDetect::Update() {
131+
// Select code for calculation of dot product based on the
132+
// value of the config variable if that value is not empty.
133+
const char* dotproduct_method = "generic";
134+
if (!strcmp(dotproduct.string(), "auto")) {
135+
// Automatic detection. Nothing to be done.
136+
} else if (!strcmp(dotproduct.string(), "generic")) {
137+
// Generic code selected by config variable.
138+
SetDotProduct(DotProductGeneric);
139+
dotproduct_method = "generic";
140+
} else if (!strcmp(dotproduct.string(), "native")) {
141+
// Native optimized code selected by config variable.
142+
SetDotProduct(DotProductNative);
143+
dotproduct_method = "native";
144+
}
145+
#if defined(X86_BUILD)
146+
else if (!strcmp(dotproduct.string(), "avx")) {
147+
// AVX selected by config variable.
148+
SetDotProduct(DotProductAVX);
149+
dotproduct_method = "avx";
150+
} else if (!strcmp(dotproduct.string(), "sse")) {
151+
// SSE selected by config variable.
152+
SetDotProduct(DotProductSSE);
153+
dotproduct_method = "sse";
154+
}
155+
#endif // X86_BUILD
156+
else {
157+
// Unsupported value of config variable.
158+
tprintf("Warning, ignoring unsupported config variable value: dotproduct=%s\n",
159+
dotproduct.string());
160+
tprintf("Support values for dotproduct: auto generic native"
161+
#if defined(X86_BUILD)
162+
" avx sse"
163+
#endif // X86_BUILD
164+
".\n");
165+
}
166+
167+
dotproduct.set_value(dotproduct_method);
83168
}
84169

85170
} // namespace tesseract

src/arch/simddetect.h

+7
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121

2222
namespace tesseract {
2323

24+
// Function pointer for best calculation of dot product.
25+
typedef double (*DotProductFunction)(const double* u, const double* v, int n);
26+
extern DotProductFunction DotProduct;
27+
2428
// Architecture detector. Add code here to detect any other architectures for
2529
// SIMD-based faster dot product functions. Intended to be a single static
2630
// object, but it does no real harm to have more than one.
@@ -41,6 +45,9 @@ class SIMDDetect {
4145
// Returns true if SSE4.1 is available on this system.
4246
static inline bool IsSSEAvailable() { return detector.sse_available_; }
4347

48+
// Update settings after config variable was set.
49+
static void Update();
50+
4451
private:
4552
// Constructor, must set all static member variables.
4653
SIMDDetect();

src/lstm/weightmatrix.cpp

+1-27
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// File: weightmatrix.cpp
33
// Description: Hides distinction between float/int implementations.
44
// Author: Ray Smith
5-
// Created: Tue Jun 17 11:46:20 PST 2014
65
//
76
// (C) Copyright 2014, Google Inc.
87
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,10 +17,8 @@
1817

1918
#include "weightmatrix.h"
2019

21-
#include "dotproductavx.h"
22-
#include "dotproductsse.h"
2320
#include "intsimdmatrix.h"
24-
#include "simddetect.h"
21+
#include "simddetect.h" // for DotProduct
2522
#include "statistc.h"
2623
#include "tprintf.h"
2724

@@ -38,29 +35,6 @@ const int kAdamCorrectionIterations = 200000;
3835
// Epsilon in Adam to prevent division by zero.
3936
const double kAdamEpsilon = 1e-8;
4037

41-
// Computes and returns the dot product of the two n-vectors u and v.
42-
static inline double DotProduct(const double* u, const double* v, int n) {
43-
// Note: because the order of addition is different among the 3 DotProduct
44-
// functions, the results can (and do) vary slightly (although they agree
45-
// to within about 4e-15). This produces different results when running
46-
// training, despite all random inputs being precisely equal.
47-
// To get consistent results, use just one of these DotProduct functions.
48-
// On a test multi-layer network, serial is 57% slower than sse, and avx
49-
// is about 8% faster than sse. This suggests that the time is memory
50-
// bandwidth constrained and could benefit from holding the reused vector
51-
// in AVX registers.
52-
53-
if (SIMDDetect::IsAVXAvailable())
54-
return DotProductAVX(u, v, n);
55-
56-
if (SIMDDetect::IsSSEAvailable())
57-
return DotProductSSE(u, v, n);
58-
59-
double total = 0.0;
60-
for (int k = 0; k < n; ++k) total += u[k] * v[k];
61-
return total;
62-
}
63-
6438
// Computes matrix.vector v = Wu.
6539
// u is of size W.dim2() - add_bias_fwd and the output v is of size
6640
// W.dim1() - skip_bias_back.

src/lstm/weightmatrix.h

-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// File: weightmatrix.h
33
// Description: Hides distinction between float/int implementations.
44
// Author: Ray Smith
5-
// Created: Tue Jun 17 09:05:39 PST 2014
65
//
76
// (C) Copyright 2014, Google Inc.
87
// Licensed under the Apache License, Version 2.0 (the "License");

0 commit comments

Comments
 (0)