#include <vector>
#include <string>
#include <utility>
- #include <algorithm>
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "../linalg.h"
@@ -60,7 +59,6 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
60
59
int num_hidden;
61
60
bool no_bias;
62
61
bool flatten;
63
-
64
62
DMLC_DECLARE_PARAMETER (FullyConnectedParam) {
65
63
// TODO(bing) add support for boolean
66
64
DMLC_DECLARE_FIELD (num_hidden).set_lower_bound (1 )
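
Side note, not part of this change: FullyConnectedParam is a plain dmlc-core parameter struct, so num_hidden, no_bias and flatten are filled from string kwargs at operator creation time. A minimal usage sketch, assuming the standard dmlc::Parameter::Init interface and the usual DMLC_REGISTER_PARAMETER registration in the corresponding .cc file; the function name and kwargs values below are made up for illustration:

#include <map>
#include <string>

// Illustrative sketch only; not code from this patch.
void InitFullyConnectedParamExample() {
  FullyConnectedParam param;
  std::map<std::string, std::string> kwargs = {
    {"num_hidden", "128"},   // must satisfy set_lower_bound(1)
    {"no_bias", "false"},
    {"flatten", "true"},
  };
  param.Init(kwargs);  // parses the declared fields, applying defaults and bounds checks
}
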
@@ -77,66 +75,6 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
  }
};

- template<typename DType>
- void AddBias(Tensor<cpu, 1, DType> bias, Tensor<cpu, 2, DType> data,
-              Tensor<cpu, 2, DType> out, Stream<cpu>*) {
-   using namespace mshadow;
-   using namespace mshadow::expr;
-   out += repmat(bias, data.size(0));
- }
-
- #if defined(__CUDACC__)
-
- namespace {
-   constexpr int nthreads_addbias = 256;
-   constexpr int nthreads_addbiasgrad_phase1 = 512;
-   constexpr int nthreads_addbiasgrad_phase2 = 128;
-   constexpr int threads_per_warp = 32;
-
-   inline int ceil_div(int x, int y) {
-     return (x + y - 1) / y;
-   }
- }  // namespace
-
- template<typename DType, typename LType>
- __global__ void add_bias_kernel(DType* mat, DType* bias, size_t lead_dim, size_t bias_length) {
-   __shared__ LType scratch[nthreads_addbias * 2];
-   const index_t N = bias_length * sizeof(DType)/sizeof(LType);
-   const index_t base = blockIdx.x * N;
-   LType* const mat_aligned = reinterpret_cast<LType*>(mat) + base;
-   const LType* const bias_aligned = reinterpret_cast<LType*>(bias);
-   LType* const scratch_bias_load = scratch + threadIdx.x;
-   DType* const scratch_bias = reinterpret_cast<DType*>(scratch_bias_load);
-   LType* const scratch_mat_load = scratch_bias_load + nthreads_addbias;
-   DType* const scratch_mat = reinterpret_cast<DType*>(scratch_mat_load);
-   for (index_t i = threadIdx.x; i < N; i += blockDim.x) {
-     *scratch_bias_load = bias_aligned[i];
-     *scratch_mat_load = mat_aligned[i];
- #pragma unroll
-     for (int j = 0; j < sizeof(LType)/sizeof(DType); ++j) {
-       scratch_mat[j] += scratch_bias[j];
-     }
-     mat_aligned[i] = *scratch_mat_load;
-   }
- }
-
- template<typename DType>
- void AddBias(Tensor<gpu, 1, DType> bias, Tensor<gpu, 2, DType> data,
-              Tensor<gpu, 2, DType> out, Stream<gpu>* s) {
-   int ltype = mxnet::common::cuda::get_load_type(bias.shape_[0] * sizeof(DType));
-   MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
-     add_bias_kernel<DType, LType><<<data.size(0),
-                                     nthreads_addbias,
-                                     0,
-                                     Stream<gpu>::GetStream(s)>>>(out.dptr_,
-                                                                  bias.dptr_,
-                                                                  data.size(0),
-                                                                  bias.shape_[0]);
-   });
- }
-
- #endif  // __CUDACC__
-
template<typename xpu, typename DType>
void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
               const std::vector<TBlob> &in_data, const std::vector<OpReqType> &req,
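
Note on the hunk above: both the deleted CPU/GPU AddBias helpers and the repmat expression that replaces the call inside FCForward (next hunk) implement the same broadcast, adding bias[j] to every row of the [batch, num_hidden] output. A minimal reference loop, purely illustrative and not MXNet API; the function name is hypothetical:

#include <cstddef>

// Illustrative only: out has shape [batch, num_hidden], bias has shape [num_hidden];
// every row of out gets the bias vector added to it.
void AddBiasReference(float* out, const float* bias,
                      std::size_t batch, std::size_t num_hidden) {
  for (std::size_t i = 0; i < batch; ++i) {
    for (std::size_t j = 0; j < num_hidden; ++j) {
      out[i * num_hidden + j] += bias[j];
    }
  }
}

The deleted add_bias_kernel computed the same update with one thread block per output row and vectorized LType loads staged through shared memory; with it gone, the forward path falls back to the mshadow repmat expression.
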
@@ -184,153 +122,10 @@ void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
        << "Incomplete bias tensor detected: bias.data().shape[1] != weight.data().shape[0]."
           " This is not supported by FCForward. If bias is in row_sparse format, please"
           " make sure all row ids are present.";
-     AddBias(bias, data, out, s);
+     out += repmat(bias, data.size(0));
  }
}

- #if defined(__CUDACC__)
-
- template<typename LType, typename DType, typename AType>
- __global__ void AddBiasGradKernelPhase1(AType * temp_space, const DType* grad,
-                                         const size_t lead_dim, const size_t other_dim) {
-   constexpr int num_warps = nthreads_addbiasgrad_phase1 / threads_per_warp;
-   const int values_per_read = sizeof(LType) >= sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1;
-   const size_t stride = lead_dim / values_per_read;
-   __shared__ AType scratch[threads_per_warp * num_warps * values_per_read];
-   LType * my_scratch_load = &(reinterpret_cast<LType *>(scratch)[threadIdx.x]);
-   DType * my_values_load = reinterpret_cast<DType *>(my_scratch_load);
-   AType * my_values_acc = &(scratch[threadIdx.x * values_per_read]);
-   AType acc[values_per_read];  // NOLINT(*)
- #pragma unroll
-   for (int i = 0; i < values_per_read; ++i) {
-     acc[i] = 0;
-   }
-   const size_t offset = blockIdx.x * threads_per_warp;
-   const int my_warp = threadIdx.x / threads_per_warp;
-   const int my_id = threadIdx.x % threads_per_warp;
-   const LType* aligned_grad = reinterpret_cast<const LType*>(grad);
-   const int rows_per_block = (other_dim + gridDim.y - 1) / gridDim.y;
-   const size_t start_row = my_warp + rows_per_block * blockIdx.y;
-   const size_t end_row = min(other_dim, static_cast<size_t>(rows_per_block * (blockIdx.y + 1)));
-   if (offset + my_id < stride) {
-     for (size_t i = start_row; i < end_row; i += num_warps) {
-       *my_scratch_load = aligned_grad[i * stride + offset + my_id];
- #pragma unroll
-       for (int j = 0; j < values_per_read; ++j) {
-         acc[j] += static_cast<AType>(my_values_load[j]);
-       }
-     }
-   }
-   __syncthreads();
- #pragma unroll
-   for (int i = 0; i < values_per_read; ++i) {
-     my_values_acc[i] = acc[i];
-   }
-
-   __syncthreads();
-
-   for (int i = num_warps / 2; i > 0; i /= 2) {
-     if (my_warp < i) {
-       const int shared_offset = values_per_read * i * threads_per_warp;
- #pragma unroll
-       for (int j = 0; j < values_per_read; ++j) {
-         my_values_acc[j] += my_values_acc[j + shared_offset];
-       }
-     }
-     __syncthreads();
-   }
-
-   if (threadIdx.x < min(threads_per_warp * values_per_read,
-                         static_cast<int>(lead_dim - values_per_read * offset))) {
-     const size_t offset_out = values_per_read * offset +
-                               blockIdx.y * lead_dim;
-     temp_space[offset_out + threadIdx.x] = scratch[threadIdx.x];
-   }
- }
-
- template<typename DType, typename AType>
- __global__ void AddBiasGradKernelPhase2(const AType * temp_space, DType * out,
-                                         int lead_dim, int n_blocks, OpReqType req) {
-   int tid = threadIdx.x + blockIdx.x * blockDim.x;
-   if (tid < lead_dim) {
-     AType acc = 0;
-     for (int i = tid; i < lead_dim * n_blocks; i += lead_dim) {
-       acc += temp_space[i];
-     }
-     KERNEL_ASSIGN(out[tid], req, static_cast<DType>(acc));
-   }
- }
-
- template<typename DType>
- void AddBiasGrad(const TBlob& in_grad,
-                  Tensor<gpu, 2, DType> grad,
-                  OpReqType req,
-                  int num_hidden,
-                  const OpContext& ctx) {
-   if (req == kNullOp) return;
-   using AType = typename mxnet_op::AccType<DType>::type;
-   mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
-   Tensor<gpu, 1, DType> gbias = in_grad.get<gpu, 1, DType>(s);
-   TBlob grad_blob = TBlob(grad);
-   TBlob gbias_blob = TBlob(gbias);
-   mxnet::TShape x(1, 0);
-   mxnet::TShape small;
-   if (shape_assign(&gbias_blob.shape_, Shape2(num_hidden, 1))) {
-     small = gbias_blob.shape_;
-   } else {
-     small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
-   }
-   const int N = small.Size();
-   int ltype = mxnet::common::cuda::get_load_type(N * sizeof(DType));
-   const int M = grad_blob.shape_.Size() / N;
-   MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
-     const unsigned int blocks_x = ceil_div(N * sizeof(DType),
-                                            threads_per_warp * sizeof(LType));
-     const unsigned int preferred_number_of_blocks = 2 *
-       MultiprocessorCount(ctx.run_ctx.ctx.dev_id);
-     const unsigned int blocks_y = std::max(preferred_number_of_blocks / blocks_x, 1u);
-     const dim3 n_blocks = {blocks_x, blocks_y, 1};
-     auto scratch_space = ctx.requested[fullc::kTempSpace]
-                            .get_space_typed<gpu, 1, AType>(mshadow::Shape1(N * blocks_y), s);
-     auto stream = mshadow::Stream<gpu>::GetStream(s);
-     AddBiasGradKernelPhase1<LType><<<n_blocks,
-                                      nthreads_addbiasgrad_phase1,
-                                      0,
-                                      stream>>>(scratch_space.dptr_,
-                                                grad.dptr_, N, M);
-     const int nblocks_phase2 = ceil_div(N, nthreads_addbiasgrad_phase2);
-     AddBiasGradKernelPhase2<<<nblocks_phase2,
-                               nthreads_addbiasgrad_phase2,
-                               0,
-                               stream>>>(scratch_space.dptr_,
-                                         gbias.dptr_, N,
-                                         blocks_y, req);
-   });
- }
- #endif
-
- template<typename DType>
- void AddBiasGrad(const TBlob& in_grad,
-                  Tensor<cpu, 2, DType> grad,
-                  OpReqType req,
-                  int num_hidden,
-                  const OpContext& ctx) {
-   mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
-   Tensor<cpu, 1, DType> gbias = in_grad.get<cpu, 1, DType>(s);
-   TBlob grad_blob = TBlob(grad);
-   TBlob gbias_blob = TBlob(gbias);
-   mxnet::TShape x(1, 0);
-   mxnet::TShape small;
-   if (shape_assign(&gbias_blob.shape_, Shape2(num_hidden, 1))) {
-     small = gbias_blob.shape_;
-   } else {
-     small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
-   }
-   ReduceAxesComputeImpl<cpu, mshadow::red::sum, false, false,
-                         mshadow_op::identity>(ctx, {grad_blob}, {req},
-                                               {in_grad}, small);
- }
-
template<typename xpu, typename DType>
void FCBackward(const OpContext &ctx, const FullyConnectedParam &param,
                const std::vector<TBlob> &out_grad, const std::vector<TBlob> &in_data,
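
Note on the block removed above: AddBiasGrad computed the bias gradient as a column-wise sum of the output gradient over the batch dimension. The GPU path did this in two phases (per-block partial sums in AddBiasGradKernelPhase1, then a final reduction in AddBiasGradKernelPhase2 over the temp_space buffer), while the CPU path delegated to ReduceAxesComputeImpl; the code inlined into FCBackward in the next hunk keeps the ReduceAxesComputeImpl route for both devices. A minimal reference loop for the same math, purely illustrative and with a hypothetical function name:

#include <cstddef>

// Illustrative only: grad has shape [batch, num_hidden], gbias has shape [num_hidden];
// gbias[j] is the sum of column j of grad over the batch.
void AddBiasGradReference(const float* grad, float* gbias,
                          std::size_t batch, std::size_t num_hidden) {
  for (std::size_t j = 0; j < num_hidden; ++j) {
    float acc = 0.0f;
    for (std::size_t i = 0; i < batch; ++i) {
      acc += grad[i * num_hidden + j];
    }
    gbias[j] = acc;
  }
}
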
@@ -374,7 +169,19 @@ void FCBackward(const OpContext &ctx, const FullyConnectedParam &param,
  linalg_gemm(grad, data, gwmat, true, false, s, req[fullc::kWeight]);
  // gradient of bias
  if (!param.no_bias) {
-     AddBiasGrad(in_grad[fullc::kBias], grad, req[fullc::kBias], param.num_hidden, ctx);
+     Tensor<xpu, 1, DType> gbias = in_grad[fullc::kBias].get<xpu, 1, DType>(s);
+     TBlob grad_blob = TBlob(grad);
+     TBlob gbias_blob = TBlob(gbias);
+     mxnet::TShape x(1, 0);
+     mxnet::TShape small;
+     if (shape_assign(&gbias_blob.shape_, Shape2(param.num_hidden, 1))) {
+       small = gbias_blob.shape_;
+     } else {
+       small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
+     }
+     ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false,
+                           mshadow_op::identity>(ctx, {grad_blob}, {req[fullc::kBias]},
+                                                 {in_grad[fullc::kBias]}, small);
  }
  // gradient of data
  // Legacy approach shown here for comparison: