18
18
#include < memory>
19
19
#include " genericvector.h"
20
20
#include " include_gunit.h"
21
- #include " intsimdmatrixavx2.h"
22
- #include " intsimdmatrixsse.h"
23
21
#include " matrix.h"
24
22
#include " simddetect.h"
25
23
#include " tprintf.h"
@@ -56,21 +54,21 @@ class IntSimdMatrixTest : public ::testing::Test {
56
54
}
57
55
return v;
58
56
}
59
- // Tests a range of sizes and compares the results against the base_ version.
60
- void ExpectEqualResults (const IntSimdMatrix* matrix) {
57
+ // Tests a range of sizes and compares the results against the generic version.
58
+ void ExpectEqualResults (const IntSimdMatrix& matrix) {
61
59
double total = 0.0 ;
62
60
for (int num_out = 1 ; num_out < 130 ; ++num_out) {
63
61
for (int num_in = 1 ; num_in < 130 ; ++num_in) {
64
62
GENERIC_2D_ARRAY<int8_t > w = InitRandom (num_out, num_in + 1 );
65
- std::vector<int8_t > u = RandomVector (num_in, * matrix);
63
+ std::vector<int8_t > u = RandomVector (num_in, matrix);
66
64
GenericVector<double > scales = RandomScales (num_out);
67
65
std::vector<double > base_result (num_out);
68
66
std::vector<int8_t > dummy;
69
- base_ .MatrixDotVector (w, dummy, scales, u.data (), base_result.data ());
67
+ IntSimdMatrix::IntSimdMatrixNative .MatrixDotVector (w, dummy, scales, u.data (), base_result.data ());
70
68
std::vector<double > test_result (num_out);
71
69
std::vector<int8_t > shaped_wi;
72
- matrix-> Init (w, shaped_wi);
73
- matrix-> MatrixDotVector (w, shaped_wi, scales, u.data (), test_result.data ());
70
+ matrix. Init (w, shaped_wi);
71
+ matrix. MatrixDotVector (w, shaped_wi, scales, u.data (), test_result.data ());
74
72
for (int i = 0 ; i < num_out; ++i) {
75
73
EXPECT_FLOAT_EQ (base_result[i], test_result[i]) << " i=" << i;
76
74
total += base_result[i];
@@ -82,13 +80,12 @@ class IntSimdMatrixTest : public ::testing::Test {
82
80
}
83
81
84
82
TRand random_;
85
- IntSimdMatrix base_ = IntSimdMatrix(1 , 1 , 1 , 1 , 1 , {});
86
83
};
87
84
88
85
// Test the C++ implementation without SIMD.
89
86
TEST_F (IntSimdMatrixTest, C) {
90
- std::unique_ptr< IntSimdMatrix> matrix (new IntSimdMatrix () );
91
- ExpectEqualResults (matrix. get () );
87
+ static const IntSimdMatrix matrix (1 , 1 , 1 , 1 , 1 , {} );
88
+ ExpectEqualResults (matrix);
92
89
}
93
90
94
91
// Tests that the SSE implementation gets the same result as the vanilla.
@@ -99,8 +96,7 @@ TEST_F(IntSimdMatrixTest, SSE) {
99
96
tprintf (" No SSE found! Not Tested!" );
100
97
return ;
101
98
}
102
- std::unique_ptr<IntSimdMatrix> matrix (new IntSimdMatrixSSE ());
103
- ExpectEqualResults (matrix.get ());
99
+ ExpectEqualResults (IntSimdMatrix::IntSimdMatrixSSE);
104
100
}
105
101
106
102
// Tests that the AVX2 implementation gets the same result as the vanilla.
@@ -111,8 +107,7 @@ TEST_F(IntSimdMatrixTest, AVX2) {
111
107
tprintf (" No AVX2 found! Not Tested!" );
112
108
return ;
113
109
}
114
- std::unique_ptr<IntSimdMatrix> matrix (new IntSimdMatrixAVX2 ());
115
- ExpectEqualResults (matrix.get ());
110
+ ExpectEqualResults (IntSimdMatrix::IntSimdMatrixAVX2);
116
111
}
117
112
118
113
} // namespace
0 commit comments