Skip to content

Commit 7cf2e2a

Browse files
committed
Overload method ForwardTimeStep (CID 1385636 Explicit null dereferenced)
This avoids NULL parameters and fixes a warning from Coverity Scan. Signed-off-by: Stefan Weil <[email protected]>
1 parent 437bf85 commit 7cf2e2a

File tree

3 files changed

+24
-18
lines changed

3 files changed

+24
-18
lines changed

lstm/fullyconnected.cpp

+19-14
Original file line numberDiff line numberDiff line change
@@ -147,15 +147,12 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
147147
int thread_id = 0;
148148
#endif
149149
double* temp_line = temp_lines[thread_id];
150-
const double* d_input = nullptr;
151-
const int8_t* i_input = nullptr;
152150
if (input.int_mode()) {
153-
i_input = input.i(t);
151+
ForwardTimeStep(input.i(t), t, temp_line);
154152
} else {
155153
input.ReadTimeStep(t, curr_input[thread_id]);
156-
d_input = curr_input[thread_id];
154+
ForwardTimeStep(curr_input[thread_id], t, temp_line);
157155
}
158-
ForwardTimeStep(d_input, i_input, t, temp_line);
159156
output->WriteTimeStep(t, temp_line);
160157
if (IsTraining() && type_ != NT_SOFTMAX) {
161158
acts_.CopyTimeStepFrom(t, *output, t);
@@ -188,15 +185,7 @@ void FullyConnected::SetupForward(const NetworkIO& input,
188185
}
189186
}
190187

191-
void FullyConnected::ForwardTimeStep(const double* d_input, const int8_t* i_input,
192-
int t, double* output_line) {
193-
// input is copied to source_ line-by-line for cache coherency.
194-
if (IsTraining() && external_source_ == nullptr && d_input != nullptr)
195-
source_t_.WriteStrided(t, d_input);
196-
if (d_input != nullptr)
197-
weights_.MatrixDotVector(d_input, output_line);
198-
else
199-
weights_.MatrixDotVector(i_input, output_line);
188+
void FullyConnected::ForwardTimeStep(int t, double* output_line) {
200189
if (type_ == NT_TANH) {
201190
FuncInplace<GFunc>(no_, output_line);
202191
} else if (type_ == NT_LOGISTIC) {
@@ -214,6 +203,22 @@ void FullyConnected::ForwardTimeStep(const double* d_input, const int8_t* i_inpu
214203
}
215204
}
216205

206+
void FullyConnected::ForwardTimeStep(const double* d_input,
207+
int t, double* output_line) {
208+
// input is copied to source_ line-by-line for cache coherency.
209+
if (IsTraining() && external_source_ == NULL)
210+
source_t_.WriteStrided(t, d_input);
211+
weights_.MatrixDotVector(d_input, output_line);
212+
ForwardTimeStep(t, output_line);
213+
}
214+
215+
void FullyConnected::ForwardTimeStep(const int8_t* i_input,
216+
int t, double* output_line) {
217+
// input is copied to source_ line-by-line for cache coherency.
218+
weights_.MatrixDotVector(i_input, output_line);
219+
ForwardTimeStep(t, output_line);
220+
}
221+
217222
// Runs backward propagation of errors on the deltas line.
218223
// See NetworkCpp for a detailed discussion of the arguments.
219224
bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,

lstm/fullyconnected.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ class FullyConnected : public Network {
9191
// Components of Forward so FullyConnected can be reused inside LSTM.
9292
void SetupForward(const NetworkIO& input,
9393
const TransposedArray* input_transpose);
94-
void ForwardTimeStep(const double* d_input, const int8_t* i_input, int t,
95-
double* output_line);
94+
void ForwardTimeStep(int t, double* output_line);
95+
void ForwardTimeStep(const double* d_input, int t, double* output_line);
96+
void ForwardTimeStep(const int8_t* i_input, int t, double* output_line);
9697

9798
// Runs backward propagation of errors on the deltas line.
9899
// See Network for a detailed discussion of the arguments.

lstm/lstm.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,9 @@ void LSTM::Forward(bool debug, const NetworkIO& input,
396396
if (softmax_ != nullptr) {
397397
if (input.int_mode()) {
398398
int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
399-
softmax_->ForwardTimeStep(nullptr, int_output->i(0), t, softmax_output);
399+
softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
400400
} else {
401-
softmax_->ForwardTimeStep(curr_output, nullptr, t, softmax_output);
401+
softmax_->ForwardTimeStep(curr_output, t, softmax_output);
402402
}
403403
output->WriteTimeStep(t, softmax_output);
404404
if (type_ == NT_LSTM_SOFTMAX_ENCODED) {

0 commit comments

Comments
 (0)