
Commit bcd6fac

ptrendx authored and larroy committed
FullyConnected Bias performance improvement on GPU (apache#16039)
* FullyConnected Bias performance improvement on GPU
* Handle req properly
* Fix after rebase
* More fixes from rebase
* Fix lint
* Trigger CI
* Fixes from review
* Fix
1 parent 88f382e commit bcd6fac

File tree: 3 files changed (+285, -22 lines)

src/common/cuda_utils.h

Lines changed: 68 additions & 8 deletions
@@ -57,6 +57,8 @@ extern __cuda_fake_struct blockIdx;
 #include <cublas_v2.h>
 #include <curand.h>
 
+#include <vector>
+
 #define STATIC_ASSERT_CUDA_VERSION_GE(min_version) \
   static_assert(CUDA_VERSION >= min_version, "Compiled-against CUDA version " \
       QUOTEVALUE(CUDA_VERSION) " is too old, please upgrade system to version " \
@@ -353,16 +355,41 @@ int get_rows_per_block(size_t row_size, int num_threads_per_block);
 }  // namespace common
 }  // namespace mxnet
 
+/*! \brief Maximum number of GPUs */
+constexpr size_t kMaxNumGpus = 64;
+
+// The implementations below assume that accesses of 32-bit ints are inherently atomic and
+// can be read/written by multiple threads without locks. The values held should be < 2^31.
+
+/*!
+ * \brief Return an attribute GPU `device_id`.
+ * \param device_id The device index of the cuda-capable gpu of interest.
+ * \param cached_values An array of attributes for already-looked-up GPUs.
+ * \param attr The attribute, by number.
+ * \param attr_name A string representation of the attribute, for error messages.
+ * \return the gpu's attribute value.
+ */
+inline int cudaAttributeLookup(int device_id, std::vector<int32_t> *cached_values,
+                               cudaDeviceAttr attr, const char *attr_name) {
+  if (device_id < 0 || device_id >= static_cast<int>(cached_values->size())) {
+    LOG(FATAL) << attr_name << "(device_id) called with invalid id: " << device_id;
+  } else if ((*cached_values)[device_id] < 0) {
+    int temp = -1;
+    CUDA_CALL(cudaDeviceGetAttribute(&temp, attr, device_id));
+    (*cached_values)[device_id] = static_cast<int32_t>(temp);
+  }
+  return (*cached_values)[device_id];
+}
+
 /*!
  * \brief Determine major version number of the gpu's cuda compute architecture.
  * \param device_id The device index of the cuda-capable gpu of interest.
  * \return the major version number of the gpu's cuda compute architecture.
  */
 inline int ComputeCapabilityMajor(int device_id) {
-  int major = 0;
-  CUDA_CALL(cudaDeviceGetAttribute(&major,
-                                   cudaDevAttrComputeCapabilityMajor, device_id));
-  return major;
+  static std::vector<int32_t> capability_major(kMaxNumGpus, -1);
+  return cudaAttributeLookup(device_id, &capability_major,
+                             cudaDevAttrComputeCapabilityMajor, "ComputeCapabilityMajor");
 }
 
 /*!
@@ -371,10 +398,9 @@ inline int ComputeCapabilityMajor(int device_id) {
  * \return the minor version number of the gpu's cuda compute architecture.
  */
 inline int ComputeCapabilityMinor(int device_id) {
-  int minor = 0;
-  CUDA_CALL(cudaDeviceGetAttribute(&minor,
-                                   cudaDevAttrComputeCapabilityMinor, device_id));
-  return minor;
+  static std::vector<int32_t> capability_minor(kMaxNumGpus, -1);
+  return cudaAttributeLookup(device_id, &capability_minor,
+                             cudaDevAttrComputeCapabilityMinor, "ComputeCapabilityMinor");
 }
 
 /*!
@@ -388,6 +414,40 @@ inline int SMArch(int device_id) {
   return 10 * major + minor;
 }
 
+/*!
+ * \brief Return the number of streaming multiprocessors of GPU `device_id`.
+ * \param device_id The device index of the cuda-capable gpu of interest.
+ * \return the gpu's count of streaming multiprocessors.
+ */
+inline int MultiprocessorCount(int device_id) {
+  static std::vector<int32_t> sm_counts(kMaxNumGpus, -1);
+  return cudaAttributeLookup(device_id, &sm_counts,
+                             cudaDevAttrMultiProcessorCount, "MultiprocessorCount");
+}
+
+/*!
+ * \brief Return the shared memory size in bytes of each of the GPU's streaming multiprocessors.
+ * \param device_id The device index of the cuda-capable gpu of interest.
+ * \return the shared memory size per streaming multiprocessor.
+ */
+inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
+  static std::vector<int32_t> max_smem_per_mutiprocessor(kMaxNumGpus, -1);
+  return cudaAttributeLookup(device_id, &max_smem_per_mutiprocessor,
+                             cudaDevAttrMaxSharedMemoryPerMultiprocessor,
+                             "MaxSharedMemoryPerMultiprocessor");
+}
+
+/*!
+ * \brief Return whether the GPU `device_id` supports cooperative-group kernel launching.
+ * \param device_id The device index of the cuda-capable gpu of interest.
+ * \return the gpu's ability to run cooperative-group kernels.
+ */
+inline bool SupportsCooperativeLaunch(int device_id) {
+  static std::vector<int32_t> coop_launch(kMaxNumGpus, -1);
+  return cudaAttributeLookup(device_id, &coop_launch,
+                             cudaDevAttrCooperativeLaunch, "SupportsCooperativeLaunch");
+}
+
 /*!
  * \brief Determine whether a cuda-capable gpu's architecture supports float16 math.
  *        Assume not if device_id is negative.
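Note: the helpers above memoize cudaDeviceGetAttribute results in a function-local static vector indexed by device id, so repeated queries such as MultiprocessorCount cost a cached read instead of a driver call. The sketch below is not part of the commit; it is a minimal standalone illustration of the same caching pattern (CachedAttribute and kMaxDevices are hypothetical names), assuming a CUDA-capable device 0 is present.

#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Illustrative only: cache one CUDA device attribute per device id.
constexpr int kMaxDevices = 64;

inline int CachedAttribute(int device_id, cudaDeviceAttr attr,
                           std::vector<int> *cache) {
  if (device_id < 0 || device_id >= static_cast<int>(cache->size())) {
    return -1;  // the helper in the commit LOG(FATAL)s on an invalid id
  }
  if ((*cache)[device_id] < 0) {        // -1 means "not looked up yet"
    int value = -1;
    cudaDeviceGetAttribute(&value, attr, device_id);
    (*cache)[device_id] = value;        // later calls skip the driver query
  }
  return (*cache)[device_id];
}

int main() {
  static std::vector<int> sm_count_cache(kMaxDevices, -1);
  int sms = CachedAttribute(0, cudaDevAttrMultiProcessorCount, &sm_count_cache);
  std::printf("device 0 has %d SMs (cached on first call)\n", sms);
  return 0;
}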

src/operator/mxnet_op.h

Lines changed: 10 additions & 0 deletions
@@ -249,6 +249,16 @@ inline int get_num_threads<cpu>(const int N) {
     LOG(FATAL) << "Unknown type enum " << type; \
   }
 
+template <typename T>
+struct AccType {
+  using type = T;
+};
+
+template <>
+struct AccType<mshadow::half::half_t> {
+  using type = float;
+};
+
 #define MXNET_REAL_ACC_TYPE_SWITCH(type, DType, AType, ...)\
   switch (type) { \
   case mshadow::kFloat32: \
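AccType maps a storage type to an accumulation type: half precision accumulates in float, every other type accumulates in itself. The sketch below is not part of the commit; it illustrates the same trait pattern with CUDA's __half standing in for mshadow::half::half_t (AccTypeSketch and sum_kernel are hypothetical names).

#include <cuda_fp16.h>   // __half

// Illustrative trait mirroring the AccType added above.
template <typename T>
struct AccTypeSketch { using type = T; };

template <>
struct AccTypeSketch<__half> { using type = float; };

// Accumulate in the wider type to avoid fp16 round-off build-up.
template <typename DType>
__global__ void sum_kernel(const DType* data, int n, float* out) {
  using AType = typename AccTypeSketch<DType>::type;
  AType acc = 0;
  for (int i = 0; i < n; ++i) {
    acc += static_cast<AType>(data[i]);
  }
  *out = static_cast<float>(acc);
}

With DType = __half the accumulator is float; with DType = float it is unchanged.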

src/operator/nn/fully_connected-inl.h

Lines changed: 207 additions & 14 deletions
@@ -32,6 +32,7 @@
 #include <vector>
 #include <string>
 #include <utility>
+#include <algorithm>
 #include "../operator_common.h"
 #include "../elemwise_op_common.h"
 #include "../linalg.h"
@@ -59,6 +60,7 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
   int num_hidden;
   bool no_bias;
   bool flatten;
+
   DMLC_DECLARE_PARAMETER(FullyConnectedParam) {
     // TODO(bing) add support for boolean
     DMLC_DECLARE_FIELD(num_hidden).set_lower_bound(1)
@@ -75,6 +77,66 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
   }
 };
 
+template<typename DType>
+void AddBias(Tensor<cpu, 1, DType> bias, Tensor<cpu, 2, DType> data,
+             Tensor<cpu, 2, DType> out, Stream<cpu>*) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  out += repmat(bias, data.size(0));
+}
+
+#if defined(__CUDACC__)
+
+namespace {
+  constexpr int nthreads_addbias = 256;
+  constexpr int nthreads_addbiasgrad_phase1 = 512;
+  constexpr int nthreads_addbiasgrad_phase2 = 128;
+  constexpr int threads_per_warp = 32;
+
+  inline int ceil_div(int x, int y) {
+    return (x + y - 1) / y;
+  }
+}  // namespace
+
+template <typename DType, typename LType>
+__global__ void add_bias_kernel(DType* mat, DType* bias, size_t lead_dim, size_t bias_length) {
+  __shared__ LType scratch[nthreads_addbias * 2];
+  const index_t N = bias_length * sizeof(DType)/sizeof(LType);
+  const index_t base = blockIdx.x * N;
+  LType* const mat_aligned = reinterpret_cast<LType*>(mat) + base;
+  const LType* const bias_aligned = reinterpret_cast<LType*>(bias);
+  LType* const scratch_bias_load = scratch + threadIdx.x;
+  DType* const scratch_bias = reinterpret_cast<DType*>(scratch_bias_load);
+  LType* const scratch_mat_load = scratch_bias_load + nthreads_addbias;
+  DType* const scratch_mat = reinterpret_cast<DType*>(scratch_mat_load);
+  for (index_t i = threadIdx.x; i < N; i += blockDim.x) {
+    *scratch_bias_load = bias_aligned[i];
+    *scratch_mat_load = mat_aligned[i];
+#pragma unroll
+    for (int j = 0; j < sizeof(LType)/sizeof(DType); ++j) {
+      scratch_mat[j] += scratch_bias[j];
+    }
+    mat_aligned[i] = *scratch_mat_load;
+  }
+}
+
+template<typename DType>
+void AddBias(Tensor<gpu, 1, DType> bias, Tensor<gpu, 2, DType> data,
+             Tensor<gpu, 2, DType> out, Stream<gpu>* s) {
+  int ltype = mxnet::common::cuda::get_load_type(bias.shape_[0] * sizeof(DType));
+  MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+    add_bias_kernel<DType, LType><<<data.size(0),
+                                    nthreads_addbias,
+                                    0,
+                                    Stream<gpu>::GetStream(s)>>>(out.dptr_,
+                                                                 bias.dptr_,
+                                                                 data.size(0),
+                                                                 bias.shape_[0]);
+  });
+}
+
+#endif  // __CUDACC__
+
 template<typename xpu, typename DType>
 void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
                const std::vector<TBlob> &in_data, const std::vector<OpReqType> &req,
@@ -122,10 +184,153 @@ void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
         << "Incomplete bias tensor detected: bias.data().shape[1] != weight.data().shape[0]."
            " This is not supported by FCForward. If bias is in row_sparse format, please"
            " make sure all row ids are present.";
-    out += repmat(bias, data.size(0));
+    AddBias(bias, data, out, s);
   }
 }
 
+#if defined (__CUDACC__)
+
+template<typename LType, typename DType, typename AType>
+__global__ void AddBiasGradKernelPhase1(AType * temp_space, const DType* grad,
+                                        const size_t lead_dim, const size_t other_dim) {
+  constexpr int num_warps = nthreads_addbiasgrad_phase1 / threads_per_warp;
+  const int values_per_read = sizeof(LType) >= sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1;
+  const size_t stride = lead_dim / values_per_read;
+  __shared__ AType scratch[threads_per_warp * num_warps * values_per_read];
+  LType * my_scratch_load = &(reinterpret_cast<LType *>(scratch)[threadIdx.x]);
+  DType * my_values_load = reinterpret_cast<DType *>(my_scratch_load);
+  AType * my_values_acc = &(scratch[threadIdx.x * values_per_read]);
+  AType acc[values_per_read];  // NOLINT(*)
+#pragma unroll
+  for (int i = 0; i < values_per_read; ++i) {
+    acc[i] = 0;
+  }
+  const size_t offset = blockIdx.x * threads_per_warp;
+  const int my_warp = threadIdx.x / threads_per_warp;
+  const int my_id = threadIdx.x % threads_per_warp;
+  const LType* aligned_grad = reinterpret_cast<const LType*>(grad);
+  const int rows_per_block = (other_dim + gridDim.y - 1) / gridDim.y;
+  const size_t start_row = my_warp + rows_per_block * blockIdx.y;
+  const size_t end_row = min(other_dim, static_cast<size_t>(rows_per_block * (blockIdx.y + 1)));
+  if (offset + my_id < stride) {
+    for (size_t i = start_row; i < end_row; i += num_warps) {
+      *my_scratch_load = aligned_grad[i * stride + offset + my_id];
+#pragma unroll
+      for (int j = 0; j < values_per_read; ++j) {
+        acc[j] += static_cast<AType>(my_values_load[j]);
+      }
+    }
+  }
+  __syncthreads();
+#pragma unroll
+  for (int i = 0; i < values_per_read; ++i) {
+    my_values_acc[i] = acc[i];
+  }
+
+  __syncthreads();
+
+  for (int i = num_warps / 2; i > 0; i /= 2) {
+    if (my_warp < i) {
+      const int shared_offset = values_per_read * i * threads_per_warp;
+#pragma unroll
+      for (int j = 0; j < values_per_read; ++j) {
+        my_values_acc[j] += my_values_acc[j + shared_offset];
+      }
+    }
+    __syncthreads();
+  }
+
+  if (threadIdx.x < min(threads_per_warp * values_per_read,
+                        static_cast<int>(lead_dim - values_per_read * offset))) {
+    const size_t offset_out = values_per_read * offset +
+                              blockIdx.y * lead_dim;
+    temp_space[offset_out + threadIdx.x] = scratch[threadIdx.x];
+  }
+}
+
+template <typename DType, typename AType>
+__global__ void AddBiasGradKernelPhase2(const AType * temp_space, DType * out,
+                                        int lead_dim, int n_blocks, OpReqType req) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < lead_dim) {
+    AType acc = 0;
+    for (int i = tid; i < lead_dim * n_blocks; i += lead_dim) {
+      acc += temp_space[i];
+    }
+    KERNEL_ASSIGN(out[tid], req, static_cast<DType>(acc));
+  }
+}
+
+template<typename DType>
+void AddBiasGrad(const TBlob& in_grad,
+                 Tensor<gpu, 2, DType> grad,
+                 OpReqType req,
+                 int num_hidden,
+                 const OpContext& ctx) {
+  if (req == kNullOp) return;
+  using AType = typename mxnet_op::AccType<DType>::type;
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  Tensor<gpu, 1, DType> gbias = in_grad.get<gpu, 1, DType>(s);
+  TBlob grad_blob = TBlob(grad);
+  TBlob gbias_blob = TBlob(gbias);
+  mxnet::TShape x(1, 0);
+  mxnet::TShape small;
+  if (shape_assign(&gbias_blob.shape_, Shape2(num_hidden, 1))) {
+    small = gbias_blob.shape_;
+  } else {
+    small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
+  }
+  const int N = small.Size();
+  int ltype = mxnet::common::cuda::get_load_type(N * sizeof(DType));
+  const int M = grad_blob.shape_.Size() / N;
+  MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+    const unsigned int blocks_x = ceil_div(N * sizeof(DType),
+                                           threads_per_warp * sizeof(LType));
+    const unsigned int preferred_number_of_blocks = 2 *
+                                                    MultiprocessorCount(ctx.run_ctx.ctx.dev_id);
+    const unsigned int blocks_y = std::max(preferred_number_of_blocks / blocks_x, 1u);
+    const dim3 n_blocks = {blocks_x, blocks_y, 1};
+    auto scratch_space = ctx.requested[fullc::kTempSpace]
+                            .get_space_typed<gpu, 1, AType>(mshadow::Shape1(N * blocks_y), s);
+    auto stream = mshadow::Stream<gpu>::GetStream(s);
+    AddBiasGradKernelPhase1<LType><<<n_blocks,
+                                     nthreads_addbiasgrad_phase1,
+                                     0,
+                                     stream>>>(scratch_space.dptr_,
+                                               grad.dptr_, N, M);
+    const int nblocks_phase2 = ceil_div(N, nthreads_addbiasgrad_phase2);
+    AddBiasGradKernelPhase2<<<nblocks_phase2,
+                              nthreads_addbiasgrad_phase2,
+                              0,
+                              stream>>>(scratch_space.dptr_,
+                                        gbias.dptr_, N,
+                                        blocks_y, req);
+  });
+}
+#endif
+
+template<typename DType>
+void AddBiasGrad(const TBlob& in_grad,
+                 Tensor<cpu, 2, DType> grad,
+                 OpReqType req,
+                 int num_hidden,
+                 const OpContext& ctx) {
+  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+  Tensor<cpu, 1, DType> gbias = in_grad.get<cpu, 1, DType>(s);
+  TBlob grad_blob = TBlob(grad);
+  TBlob gbias_blob = TBlob(gbias);
+  mxnet::TShape x(1, 0);
+  mxnet::TShape small;
+  if (shape_assign(&gbias_blob.shape_, Shape2(num_hidden, 1))) {
+    small = gbias_blob.shape_;
+  } else {
+    small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
+  }
+  ReduceAxesComputeImpl<cpu, mshadow::red::sum, false, false,
+                        mshadow_op::identity>(ctx, {grad_blob}, {req},
+                                              {in_grad}, small);
+}
+
 template<typename xpu, typename DType>
 void FCBackward(const OpContext &ctx, const FullyConnectedParam &param,
                 const std::vector<TBlob> &out_grad, const std::vector<TBlob> &in_data,
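The bias gradient is a column-wise sum of the (batch x num_hidden) incoming gradient, computed in two phases: phase 1 has each block of a 2-D grid reduce a slab of rows into a partial-sum row stored in the requested kTempSpace buffer (the grid's y-extent is derived from 2 * MultiprocessorCount so small layers still fill the GPU), and phase 2 folds the partial rows into the final bias gradient, honoring req via KERNEL_ASSIGN. The sketch below is not part of the commit; it is a simplified scalar version of the same two-phase reduction (col_sum_phase1 and col_sum_phase2 are hypothetical names) without the vectorized loads and warp tiling of the real kernels.

#include <cuda_runtime.h>

// Phase 1: each blockIdx.y reduces rows_per_block rows of the row-major
// (M x N) grad into one partial row of length N in temp.
__global__ void col_sum_phase1(const float* grad, float* temp,
                               int M, int N, int rows_per_block) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= N) return;
  const int row_start = blockIdx.y * rows_per_block;
  const int row_end = min(M, row_start + rows_per_block);
  float acc = 0.f;
  for (int r = row_start; r < row_end; ++r) {
    acc += grad[r * N + col];
  }
  temp[blockIdx.y * N + col] = acc;
}

// Phase 2: fold the n_partials partial rows into the final N-element result.
__global__ void col_sum_phase2(const float* temp, float* out, int N, int n_partials) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= N) return;
  float acc = 0.f;
  for (int p = 0; p < n_partials; ++p) {
    acc += temp[p * N + col];
  }
  out[col] = acc;  // the real kernel goes through KERNEL_ASSIGN to honor req
}

// Launch (N columns, M rows, n_partials partial rows):
//   dim3 grid1((N + 127) / 128, n_partials);
//   col_sum_phase1<<<grid1, 128, 0, stream>>>(grad, temp, M, N, (M + n_partials - 1) / n_partials);
//   col_sum_phase2<<<(N + 127) / 128, 128, 0, stream>>>(temp, dbias, N, n_partials);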
@@ -169,19 +374,7 @@ void FCBackward(const OpContext &ctx, const FullyConnectedParam &param,
   linalg_gemm(grad, data, gwmat, true, false, s, req[fullc::kWeight]);
   // gradient of bias
   if (!param.no_bias) {
-    Tensor<xpu, 1, DType> gbias = in_grad[fullc::kBias].get<xpu, 1, DType>(s);
-    TBlob grad_blob = TBlob(grad);
-    TBlob gbias_blob = TBlob(gbias);
-    mxnet::TShape x(1, 0);
-    mxnet::TShape small;
-    if (shape_assign(&gbias_blob.shape_, Shape2(param.num_hidden, 1))) {
-      small = gbias_blob.shape_;
-    } else {
-      small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
-    }
-    ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false,
-                          mshadow_op::identity>(ctx, {grad_blob}, {req[fullc::kBias]},
-                                                {in_grad[fullc::kBias]}, small);
+    AddBiasGrad(in_grad[fullc::kBias], grad, req[fullc::kBias], param.num_hidden, ctx);
   }
   // gradient of data
   // Legacy approach shown here for comparison:
