@@ -147,15 +147,12 @@ void FullyConnected::Forward(bool debug, const NetworkIO& input,
147
147
int thread_id = 0 ;
148
148
#endif
149
149
double * temp_line = temp_lines[thread_id];
150
- const double * d_input = nullptr ;
151
- const int8_t * i_input = nullptr ;
152
150
if (input.int_mode ()) {
153
- i_input = input.i (t);
151
+ ForwardTimeStep ( input.i (t), t, temp_line );
154
152
} else {
155
153
input.ReadTimeStep (t, curr_input[thread_id]);
156
- d_input = curr_input[thread_id];
154
+ ForwardTimeStep ( curr_input[thread_id], t, temp_line) ;
157
155
}
158
- ForwardTimeStep (d_input, i_input, t, temp_line);
159
156
output->WriteTimeStep (t, temp_line);
160
157
if (IsTraining () && type_ != NT_SOFTMAX) {
161
158
acts_.CopyTimeStepFrom (t, *output, t);
@@ -188,15 +185,7 @@ void FullyConnected::SetupForward(const NetworkIO& input,
188
185
}
189
186
}
190
187
191
- void FullyConnected::ForwardTimeStep (const double * d_input, const int8_t * i_input,
192
- int t, double * output_line) {
193
- // input is copied to source_ line-by-line for cache coherency.
194
- if (IsTraining () && external_source_ == nullptr && d_input != nullptr )
195
- source_t_.WriteStrided (t, d_input);
196
- if (d_input != nullptr )
197
- weights_.MatrixDotVector (d_input, output_line);
198
- else
199
- weights_.MatrixDotVector (i_input, output_line);
188
+ void FullyConnected::ForwardTimeStep (int t, double * output_line) {
200
189
if (type_ == NT_TANH) {
201
190
FuncInplace<GFunc>(no_, output_line);
202
191
} else if (type_ == NT_LOGISTIC) {
@@ -214,6 +203,22 @@ void FullyConnected::ForwardTimeStep(const double* d_input, const int8_t* i_inpu
214
203
}
215
204
}
216
205
206
+ void FullyConnected::ForwardTimeStep (const double * d_input,
207
+ int t, double * output_line) {
208
+ // input is copied to source_ line-by-line for cache coherency.
209
+ if (IsTraining () && external_source_ == NULL )
210
+ source_t_.WriteStrided (t, d_input);
211
+ weights_.MatrixDotVector (d_input, output_line);
212
+ ForwardTimeStep (t, output_line);
213
+ }
214
+
215
+ void FullyConnected::ForwardTimeStep (const int8_t * i_input,
216
+ int t, double * output_line) {
217
+ // input is copied to source_ line-by-line for cache coherency.
218
+ weights_.MatrixDotVector (i_input, output_line);
219
+ ForwardTimeStep (t, output_line);
220
+ }
221
+
217
222
// Runs backward propagation of errors on the deltas line.
218
223
// See NetworkCpp for a detailed discussion of the arguments.
219
224
bool FullyConnected::Backward (bool debug, const NetworkIO& fwd_deltas,
0 commit comments