
Commit 745c25c

Revert "FullyConnected Bias performance improvement on GPU (apache#16039)"
This reverts commit a5e698a.
Parent: ecda000 · Commit: 745c25c

3 files changed (+22 −285 lines)

src/common/cuda_utils.h

Lines changed: 8 additions & 68 deletions
@@ -57,8 +57,6 @@ extern __cuda_fake_struct blockIdx;
 #include <cublas_v2.h>
 #include <curand.h>
 
-#include <vector>
-
 #define STATIC_ASSERT_CUDA_VERSION_GE(min_version) \
   static_assert(CUDA_VERSION >= min_version, "Compiled-against CUDA version " \
       QUOTEVALUE(CUDA_VERSION) " is too old, please upgrade system to version " \
@@ -429,41 +427,16 @@ int get_rows_per_block(size_t row_size, int num_threads_per_block);
 } // namespace common
 } // namespace mxnet
 
-/*! \brief Maximum number of GPUs */
-constexpr size_t kMaxNumGpus = 64;
-
-// The implementations below assume that accesses of 32-bit ints are inherently atomic and
-// can be read/written by multiple threads without locks. The values held should be < 2^31.
-
-/*!
- * \brief Return an attribute GPU `device_id`.
- * \param device_id The device index of the cuda-capable gpu of interest.
- * \param cached_values An array of attributes for already-looked-up GPUs.
- * \param attr The attribute, by number.
- * \param attr_name A string representation of the attribute, for error messages.
- * \return the gpu's attribute value.
- */
-inline int cudaAttributeLookup(int device_id, std::vector<int32_t> *cached_values,
-                               cudaDeviceAttr attr, const char *attr_name) {
-  if (device_id < 0 || device_id >= static_cast<int>(cached_values->size())) {
-    LOG(FATAL) << attr_name << "(device_id) called with invalid id: " << device_id;
-  } else if ((*cached_values)[device_id] < 0) {
-    int temp = -1;
-    CUDA_CALL(cudaDeviceGetAttribute(&temp, attr, device_id));
-    (*cached_values)[device_id] = static_cast<int32_t>(temp);
-  }
-  return (*cached_values)[device_id];
-}
-
 /*!
  * \brief Determine major version number of the gpu's cuda compute architecture.
  * \param device_id The device index of the cuda-capable gpu of interest.
  * \return the major version number of the gpu's cuda compute architecture.
  */
 inline int ComputeCapabilityMajor(int device_id) {
-  static std::vector<int32_t> capability_major(kMaxNumGpus, -1);
-  return cudaAttributeLookup(device_id, &capability_major,
-                             cudaDevAttrComputeCapabilityMajor, "ComputeCapabilityMajor");
+  int major = 0;
+  CUDA_CALL(cudaDeviceGetAttribute(&major,
+                                   cudaDevAttrComputeCapabilityMajor, device_id));
+  return major;
 }
 
 /*!
@@ -472,9 +445,10 @@ inline int ComputeCapabilityMajor(int device_id) {
  * \return the minor version number of the gpu's cuda compute architecture.
  */
 inline int ComputeCapabilityMinor(int device_id) {
-  static std::vector<int32_t> capability_minor(kMaxNumGpus, -1);
-  return cudaAttributeLookup(device_id, &capability_minor,
-                             cudaDevAttrComputeCapabilityMinor, "ComputeCapabilityMinor");
+  int minor = 0;
+  CUDA_CALL(cudaDeviceGetAttribute(&minor,
+                                   cudaDevAttrComputeCapabilityMinor, device_id));
+  return minor;
 }
 
 /*!
@@ -488,40 +462,6 @@ inline int SMArch(int device_id) {
   return 10 * major + minor;
 }
 
-/*!
- * \brief Return the number of streaming multiprocessors of GPU `device_id`.
- * \param device_id The device index of the cuda-capable gpu of interest.
- * \return the gpu's count of streaming multiprocessors.
- */
-inline int MultiprocessorCount(int device_id) {
-  static std::vector<int32_t> sm_counts(kMaxNumGpus, -1);
-  return cudaAttributeLookup(device_id, &sm_counts,
-                             cudaDevAttrMultiProcessorCount, "MultiprocessorCount");
-}
-
-/*!
- * \brief Return the shared memory size in bytes of each of the GPU's streaming multiprocessors.
- * \param device_id The device index of the cuda-capable gpu of interest.
- * \return the shared memory size per streaming multiprocessor.
- */
-inline int MaxSharedMemoryPerMultiprocessor(int device_id) {
-  static std::vector<int32_t> max_smem_per_mutiprocessor(kMaxNumGpus, -1);
-  return cudaAttributeLookup(device_id, &max_smem_per_mutiprocessor,
-                             cudaDevAttrMaxSharedMemoryPerMultiprocessor,
-                             "MaxSharedMemoryPerMultiprocessor");
-}
-
-/*!
- * \brief Return whether the GPU `device_id` supports cooperative-group kernel launching.
- * \param device_id The device index of the cuda-capable gpu of interest.
- * \return the gpu's ability to run cooperative-group kernels.
- */
-inline bool SupportsCooperativeLaunch(int device_id) {
-  static std::vector<int32_t> coop_launch(kMaxNumGpus, -1);
-  return cudaAttributeLookup(device_id, &coop_launch,
-                             cudaDevAttrCooperativeLaunch, "SupportsCooperativeLaunch");
-}
-
 /*!
  * \brief Determine whether a cuda-capable gpu's architecture supports float16 math.
  *        Assume not if device_id is negative.
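
Note: the revert swaps the cached per-device attribute lookup (a static std::vector indexed by device id, capped at kMaxNumGpus) for a direct cudaDeviceGetAttribute call on every query. A minimal standalone sketch of the restored pattern, using the plain CUDA runtime API instead of MXNet's CUDA_CALL macro (the function name here is hypothetical):

// Query a device attribute directly each time, as the reverted-to code does,
// rather than memoizing it in a static table as the removed cudaAttributeLookup did.
#include <cuda_runtime.h>
#include <cstdio>

int ComputeCapabilityMajorSketch(int device_id) {
  int major = 0;
  // cudaDeviceGetAttribute(int* value, cudaDeviceAttr attr, int device)
  cudaError_t err = cudaDeviceGetAttribute(
      &major, cudaDevAttrComputeCapabilityMajor, device_id);
  if (err != cudaSuccess) {
    std::fprintf(stderr, "cudaDeviceGetAttribute failed: %s\n", cudaGetErrorString(err));
    return -1;
  }
  return major;
}

int main() {
  std::printf("device 0 compute capability major: %d\n", ComputeCapabilityMajorSketch(0));
  return 0;
}

The trade-off is a small extra runtime query per call in exchange for dropping the fixed kMaxNumGpus limit and the shared mutable cache.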

src/operator/mxnet_op.h

Lines changed: 0 additions & 10 deletions
@@ -249,16 +249,6 @@ inline int get_num_threads<cpu>(const int N) {
     LOG(FATAL) << "Unknown type enum " << type; \
   }
 
-template <typename T>
-struct AccType {
-  using type = T;
-};
-
-template <>
-struct AccType<mshadow::half::half_t> {
-  using type = float;
-};
-
 #define MXNET_REAL_ACC_TYPE_SWITCH(type, DType, AType, ...)\
   switch (type) { \
   case mshadow::kFloat32: \
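
The removed AccType trait picks a wider accumulation type for reduced-precision inputs: float for mshadow's half_t, the input type otherwise. A minimal standalone sketch of the same idea, with a placeholder Half struct standing in for mshadow::half::half_t:

// Accumulation-type trait in the spirit of the removed AccType:
// accumulate in the input type by default, but promote 16-bit floats to float
// so long sums lose less precision.
#include <type_traits>

struct Half { unsigned short bits; };  // placeholder for a 16-bit float type

template <typename T>
struct AccTypeSketch {
  using type = T;
};

template <>
struct AccTypeSketch<Half> {
  using type = float;
};

static_assert(std::is_same<AccTypeSketch<double>::type, double>::value, "double stays double");
static_assert(std::is_same<AccTypeSketch<Half>::type, float>::value, "half accumulates in float");

int main() { return 0; }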

src/operator/nn/fully_connected-inl.h

Lines changed: 14 additions & 207 deletions
@@ -32,7 +32,6 @@
 #include <vector>
 #include <string>
 #include <utility>
-#include <algorithm>
 #include "../operator_common.h"
 #include "../elemwise_op_common.h"
 #include "../linalg.h"
@@ -60,7 +59,6 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
   int num_hidden;
   bool no_bias;
   bool flatten;
-
   DMLC_DECLARE_PARAMETER(FullyConnectedParam) {
     // TODO(bing) add support for boolean
     DMLC_DECLARE_FIELD(num_hidden).set_lower_bound(1)
@@ -77,66 +75,6 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
   }
 };
 
-template<typename DType>
-void AddBias(Tensor<cpu, 1, DType> bias, Tensor<cpu, 2, DType> data,
-             Tensor<cpu, 2, DType> out, Stream<cpu>*) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  out += repmat(bias, data.size(0));
-}
-
-#if defined(__CUDACC__)
-
-namespace {
-  constexpr int nthreads_addbias = 256;
-  constexpr int nthreads_addbiasgrad_phase1 = 512;
-  constexpr int nthreads_addbiasgrad_phase2 = 128;
-  constexpr int threads_per_warp = 32;
-
-  inline int ceil_div(int x, int y) {
-    return (x + y - 1) / y;
-  }
-} // namespace
-
-template <typename DType, typename LType>
-__global__ void add_bias_kernel(DType* mat, DType* bias, size_t lead_dim, size_t bias_length) {
-  __shared__ LType scratch[nthreads_addbias * 2];
-  const index_t N = bias_length * sizeof(DType)/sizeof(LType);
-  const index_t base = blockIdx.x * N;
-  LType* const mat_aligned = reinterpret_cast<LType*>(mat) + base;
-  const LType* const bias_aligned = reinterpret_cast<LType*>(bias);
-  LType* const scratch_bias_load = scratch + threadIdx.x;
-  DType* const scratch_bias = reinterpret_cast<DType*>(scratch_bias_load);
-  LType* const scratch_mat_load = scratch_bias_load + nthreads_addbias;
-  DType* const scratch_mat = reinterpret_cast<DType*>(scratch_mat_load);
-  for (index_t i = threadIdx.x; i < N; i += blockDim.x) {
-    *scratch_bias_load = bias_aligned[i];
-    *scratch_mat_load = mat_aligned[i];
-#pragma unroll
-    for (int j = 0; j < sizeof(LType)/sizeof(DType); ++j) {
-      scratch_mat[j] += scratch_bias[j];
-    }
-    mat_aligned[i] = *scratch_mat_load;
-  }
-}
-
-template<typename DType>
-void AddBias(Tensor<gpu, 1, DType> bias, Tensor<gpu, 2, DType> data,
-             Tensor<gpu, 2, DType> out, Stream<gpu>* s) {
-  int ltype = mxnet::common::cuda::get_load_type(bias.shape_[0] * sizeof(DType));
-  MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
-    add_bias_kernel<DType, LType><<<data.size(0),
-                                    nthreads_addbias,
-                                    0,
-                                    Stream<gpu>::GetStream(s)>>>(out.dptr_,
-                                                                 bias.dptr_,
-                                                                 data.size(0),
-                                                                 bias.shape_[0]);
-  });
-}
-
-#endif  // __CUDACC__
-
 template<typename xpu, typename DType>
 void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
                const std::vector<TBlob> &in_data, const std::vector<OpReqType> &req,
@@ -184,153 +122,10 @@ void FCForward(const OpContext &ctx, const FullyConnectedParam &param,
         << "Incomplete bias tensor detected: bias.data().shape[1] != weight.data().shape[0]."
            " This is not supported by FCForward. If bias is in row_sparse format, please"
            " make sure all row ids are present.";
-    AddBias(bias, data, out, s);
+    out += repmat(bias, data.size(0));
   }
 }
 
-#if defined (__CUDACC__)
-
-template<typename LType, typename DType, typename AType>
-__global__ void AddBiasGradKernelPhase1(AType * temp_space, const DType* grad,
-                                        const size_t lead_dim, const size_t other_dim) {
-  constexpr int num_warps = nthreads_addbiasgrad_phase1 / threads_per_warp;
-  const int values_per_read = sizeof(LType) >= sizeof(DType) ? sizeof(LType) / sizeof(DType) : 1;
-  const size_t stride = lead_dim / values_per_read;
-  __shared__ AType scratch[threads_per_warp * num_warps * values_per_read];
-  LType * my_scratch_load = &(reinterpret_cast<LType *>(scratch)[threadIdx.x]);
-  DType * my_values_load = reinterpret_cast<DType *>(my_scratch_load);
-  AType * my_values_acc = &(scratch[threadIdx.x * values_per_read]);
-  AType acc[values_per_read]; // NOLINT(*)
-#pragma unroll
-  for (int i = 0; i < values_per_read; ++i) {
-    acc[i] = 0;
-  }
-  const size_t offset = blockIdx.x * threads_per_warp;
-  const int my_warp = threadIdx.x / threads_per_warp;
-  const int my_id = threadIdx.x % threads_per_warp;
-  const LType* aligned_grad = reinterpret_cast<const LType*>(grad);
-  const int rows_per_block = (other_dim + gridDim.y - 1) / gridDim.y;
-  const size_t start_row = my_warp + rows_per_block * blockIdx.y;
-  const size_t end_row = min(other_dim, static_cast<size_t>(rows_per_block * (blockIdx.y + 1)));
-  if (offset + my_id < stride) {
-    for (size_t i = start_row; i < end_row; i += num_warps) {
-      *my_scratch_load = aligned_grad[i * stride + offset + my_id];
-#pragma unroll
-      for (int j = 0; j < values_per_read; ++j) {
-        acc[j] += static_cast<AType>(my_values_load[j]);
-      }
-    }
-  }
-  __syncthreads();
-#pragma unroll
-  for (int i = 0; i < values_per_read; ++i) {
-    my_values_acc[i] = acc[i];
-  }
-
-  __syncthreads();
-
-  for (int i = num_warps / 2; i > 0; i /= 2) {
-    if (my_warp < i) {
-      const int shared_offset = values_per_read * i * threads_per_warp;
-#pragma unroll
-      for (int j = 0; j < values_per_read; ++j) {
-        my_values_acc[j] += my_values_acc[j + shared_offset];
-      }
-    }
-    __syncthreads();
-  }
-
-  if (threadIdx.x < min(threads_per_warp * values_per_read,
-                        static_cast<int>(lead_dim - values_per_read * offset))) {
-    const size_t offset_out = values_per_read * offset +
-                              blockIdx.y * lead_dim;
-    temp_space[offset_out + threadIdx.x] = scratch[threadIdx.x];
-  }
-}
-
-template <typename DType, typename AType>
-__global__ void AddBiasGradKernelPhase2(const AType * temp_space, DType * out,
-                                        int lead_dim, int n_blocks, OpReqType req) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid < lead_dim) {
-    AType acc = 0;
-    for (int i = tid; i < lead_dim * n_blocks; i += lead_dim) {
-      acc += temp_space[i];
-    }
-    KERNEL_ASSIGN(out[tid], req, static_cast<DType>(acc));
-  }
-}
-
-template<typename DType>
-void AddBiasGrad(const TBlob& in_grad,
-                 Tensor<gpu, 2, DType> grad,
-                 OpReqType req,
-                 int num_hidden,
-                 const OpContext& ctx) {
-  if (req == kNullOp) return;
-  using AType = typename mxnet_op::AccType<DType>::type;
-  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
-  Tensor<gpu, 1, DType> gbias = in_grad.get<gpu, 1, DType>(s);
-  TBlob grad_blob = TBlob(grad);
-  TBlob gbias_blob = TBlob(gbias);
-  mxnet::TShape x(1, 0);
-  mxnet::TShape small;
-  if (shape_assign(&gbias_blob.shape_, Shape2(num_hidden, 1))) {
-    small = gbias_blob.shape_;
-  } else {
-    small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
-  }
-  const int N = small.Size();
-  int ltype = mxnet::common::cuda::get_load_type(N * sizeof(DType));
-  const int M = grad_blob.shape_.Size() / N;
-  MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
-    const unsigned int blocks_x = ceil_div(N * sizeof(DType),
-                                           threads_per_warp * sizeof(LType));
-    const unsigned int preferred_number_of_blocks = 2 *
-                                                    MultiprocessorCount(ctx.run_ctx.ctx.dev_id);
-    const unsigned int blocks_y = std::max(preferred_number_of_blocks / blocks_x, 1u);
-    const dim3 n_blocks = {blocks_x, blocks_y, 1};
-    auto scratch_space = ctx.requested[fullc::kTempSpace]
-                            .get_space_typed<gpu, 1, AType>(mshadow::Shape1(N * blocks_y), s);
-    auto stream = mshadow::Stream<gpu>::GetStream(s);
-    AddBiasGradKernelPhase1<LType><<<n_blocks,
-                                     nthreads_addbiasgrad_phase1,
-                                     0,
-                                     stream>>>(scratch_space.dptr_,
-                                               grad.dptr_, N, M);
-    const int nblocks_phase2 = ceil_div(N, nthreads_addbiasgrad_phase2);
-    AddBiasGradKernelPhase2<<<nblocks_phase2,
-                              nthreads_addbiasgrad_phase2,
-                              0,
-                              stream>>>(scratch_space.dptr_,
-                                        gbias.dptr_, N,
-                                        blocks_y, req);
-  });
-}
-#endif
-
-template<typename DType>
-void AddBiasGrad(const TBlob& in_grad,
-                 Tensor<cpu, 2, DType> grad,
-                 OpReqType req,
-                 int num_hidden,
-                 const OpContext& ctx) {
-  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
-  Tensor<cpu, 1, DType> gbias = in_grad.get<cpu, 1, DType>(s);
-  TBlob grad_blob = TBlob(grad);
-  TBlob gbias_blob = TBlob(gbias);
-  mxnet::TShape x(1, 0);
-  mxnet::TShape small;
-  if (shape_assign(&gbias_blob.shape_, Shape2(num_hidden, 1))) {
-    small = gbias_blob.shape_;
-  } else {
-    small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
-  }
-  ReduceAxesComputeImpl<cpu, mshadow::red::sum, false, false,
-                        mshadow_op::identity>(ctx, {grad_blob}, {req},
-                                              {in_grad}, small);
-}
-
 template<typename xpu, typename DType>
 void FCBackward(const OpContext &ctx, const FullyConnectedParam &param,
                 const std::vector<TBlob> &out_grad, const std::vector<TBlob> &in_data,
@@ -374,7 +169,19 @@ void FCBackward(const OpContext &ctx, const FullyConnectedParam &param,
   linalg_gemm(grad, data, gwmat, true, false, s, req[fullc::kWeight]);
   // gradient of bias
   if (!param.no_bias) {
-    AddBiasGrad(in_grad[fullc::kBias], grad, req[fullc::kBias], param.num_hidden, ctx);
+    Tensor<xpu, 1, DType> gbias = in_grad[fullc::kBias].get<xpu, 1, DType>(s);
+    TBlob grad_blob = TBlob(grad);
+    TBlob gbias_blob = TBlob(gbias);
+    mxnet::TShape x(1, 0);
+    mxnet::TShape small;
+    if (shape_assign(&gbias_blob.shape_, Shape2(param.num_hidden, 1))) {
+      small = gbias_blob.shape_;
+    } else {
+      small = ReduceAxesShapeImpl(grad_blob.shape_, dmlc::optional<mxnet::TShape>(x), true, false);
+    }
+    ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, false,
+                          mshadow_op::identity>(ctx, {grad_blob}, {req[fullc::kBias]},
+                                                {in_grad[fullc::kBias]}, small);
   }
   // gradient of data
   // Legacy approach shown here for comparison:
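
Net effect of the revert in this file: the forward bias add goes back to mshadow's repmat broadcast and the bias gradient goes back to a generic sum-over-rows reduction (ReduceAxesComputeImpl), replacing the hand-written CUDA kernels. A standalone CPU reference sketch of those two semantics (the helper names here are illustrative, not MXNet APIs):

// Reference semantics restored by the revert, written as plain loops:
//   forward:  out[i][j] += bias[j]            (broadcast the bias over rows)
//   backward: gbias[j]   = sum_i grad[i][j]   (column-wise sum of the gradient)
#include <cstdio>
#include <vector>

void AddBiasRef(std::vector<float>* out, const std::vector<float>& bias,
                int rows, int cols) {
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < cols; ++j)
      (*out)[i * cols + j] += bias[j];
}

std::vector<float> BiasGradRef(const std::vector<float>& grad, int rows, int cols) {
  std::vector<float> gbias(cols, 0.0f);
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < cols; ++j)
      gbias[j] += grad[i * cols + j];
  return gbias;
}

int main() {
  const int rows = 2, cols = 3;
  std::vector<float> out(rows * cols, 1.0f);           // stand-in for data * weight^T
  const std::vector<float> bias = {0.5f, 1.0f, 1.5f};
  AddBiasRef(&out, bias, rows, cols);
  const std::vector<float> gbias = BiasGradRef(out, rows, cols);  // reuse out as a fake gradient
  for (int j = 0; j < cols; ++j) std::printf("gbias[%d] = %g\n", j, gbias[j]);
  return 0;
}

The removed kernels targeted the same math with vectorized loads and a two-phase reduction; the revert returns to the simpler generic code paths sketched above.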
