Skip to content

Commit c79d613

Browse files
committed
Replace ASSERT_HOST by assert
Signed-off-by: Stefan Weil <[email protected]>
1 parent f75b2c1 commit c79d613

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

src/lstm/weightmatrix.cpp

+16-15
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "weightmatrix.h"
1919

20+
#include <cassert> // for assert
2021
#include "intsimdmatrix.h"
2122
#include "simddetect.h" // for DotProduct
2223
#include "statistc.h"
@@ -238,21 +239,21 @@ bool WeightMatrix::DeSerializeOld(bool training, TFile* fp) {
238239
// implement the bias, but it doesn't actually have it.
239240
// Asserts that the call matches what we have.
240241
void WeightMatrix::MatrixDotVector(const double* u, double* v) const {
241-
ASSERT_HOST(!int_mode_);
242+
assert(!int_mode_);
242243
MatrixDotVectorInternal(wf_, true, false, u, v);
243244
}
244245

245246
void WeightMatrix::MatrixDotVector(const int8_t* u, double* v) const {
246-
ASSERT_HOST(int_mode_);
247-
ASSERT_HOST(multiplier_ != nullptr);
247+
assert(int_mode_);
248+
assert(multiplier_ != nullptr);
248249
multiplier_->MatrixDotVector(wi_, scales_, u, v);
249250
}
250251

251252
// MatrixDotVector for peep weights, MultiplyAccumulate adds the
252253
// component-wise products of *this[0] and v to inout.
253254
void WeightMatrix::MultiplyAccumulate(const double* v, double* inout) {
254-
ASSERT_HOST(!int_mode_);
255-
ASSERT_HOST(wf_.dim1() == 1);
255+
assert(!int_mode_);
256+
assert(wf_.dim1() == 1);
256257
int n = wf_.dim2();
257258
const double* u = wf_[0];
258259
for (int i = 0; i < n; ++i) {
@@ -265,7 +266,7 @@ void WeightMatrix::MultiplyAccumulate(const double* v, double* inout) {
265266
// The last result is discarded, as v is assumed to have an imaginary
266267
// last value of 1, as with MatrixDotVector.
267268
void WeightMatrix::VectorDotMatrix(const double* u, double* v) const {
268-
ASSERT_HOST(!int_mode_);
269+
assert(!int_mode_);
269270
MatrixDotVectorInternal(wf_t_, false, true, u, v);
270271
}
271272

@@ -277,14 +278,14 @@ void WeightMatrix::VectorDotMatrix(const double* u, double* v) const {
277278
void WeightMatrix::SumOuterTransposed(const TransposedArray& u,
278279
const TransposedArray& v,
279280
bool in_parallel) {
280-
ASSERT_HOST(!int_mode_);
281+
assert(!int_mode_);
281282
int num_outputs = dw_.dim1();
282-
ASSERT_HOST(u.dim1() == num_outputs);
283-
ASSERT_HOST(u.dim2() == v.dim2());
283+
assert(u.dim1() == num_outputs);
284+
assert(u.dim2() == v.dim2());
284285
int num_inputs = dw_.dim2() - 1;
285286
int num_samples = u.dim2();
286287
// v is missing the last element in dim1.
287-
ASSERT_HOST(v.dim1() == num_inputs);
288+
assert(v.dim1() == num_inputs);
288289
#ifdef _OPENMP
289290
#pragma omp parallel for num_threads(4) if (in_parallel)
290291
#endif
@@ -306,7 +307,7 @@ void WeightMatrix::SumOuterTransposed(const TransposedArray& u,
306307
// use_adam_ is true.
307308
void WeightMatrix::Update(double learning_rate, double momentum,
308309
double adam_beta, int num_samples) {
309-
ASSERT_HOST(!int_mode_);
310+
assert(!int_mode_);
310311
if (use_adam_ && num_samples > 0 && num_samples < kAdamCorrectionIterations) {
311312
learning_rate *= sqrt(1.0 - pow(adam_beta, num_samples));
312313
learning_rate /= 1.0 - pow(momentum, num_samples);
@@ -328,8 +329,8 @@ void WeightMatrix::Update(double learning_rate, double momentum,
328329

329330
// Adds the dw_ in other to the dw_ is *this.
330331
void WeightMatrix::AddDeltas(const WeightMatrix& other) {
331-
ASSERT_HOST(dw_.dim1() == other.dw_.dim1());
332-
ASSERT_HOST(dw_.dim2() == other.dw_.dim2());
332+
assert(dw_.dim1() == other.dw_.dim1());
333+
assert(dw_.dim2() == other.dw_.dim2());
333334
dw_ += other.dw_;
334335
}
335336

@@ -340,8 +341,8 @@ void WeightMatrix::CountAlternators(const WeightMatrix& other, double* same,
340341
double* changed) const {
341342
int num_outputs = updates_.dim1();
342343
int num_inputs = updates_.dim2();
343-
ASSERT_HOST(num_outputs == other.updates_.dim1());
344-
ASSERT_HOST(num_inputs == other.updates_.dim2());
344+
assert(num_outputs == other.updates_.dim1());
345+
assert(num_inputs == other.updates_.dim2());
345346
for (int i = 0; i < num_outputs; ++i) {
346347
const double* this_i = updates_[i];
347348
const double* other_i = other.updates_[i];

0 commit comments

Comments
 (0)