
Commit 2d58ff5

[Bug Fixed] Fix batch norm when grad_req is add (#18500)
* fix batch norm when fix_gamma is True
* support gradient accumulation for batch norm
* mkldnn batchnorm support grad add
* unittest for bn
* fix bn arg
* fix lint
* fix mkldnn
* fix mkldnn bn
* fix grad when fixing gamma
* fix naive gpu bn
* fix lint
* fix cudnn bn
* fix flag
* fix lint
* fix testcase
* fix
* use @pytest.mark.parametrize
* combination
* remove redundant test in batchnorm
* npx.batch_norm test
* try to fix test
* reduce the number of tests for batchnorm
* fix
1 parent 992ed3c commit 2d58ff5
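
The commit makes the batch norm backward kernels honour the gradient request (grad_req) instead of always overwriting the input gradient, so gradient accumulation with grad_req="add" now produces correct results. Below is a minimal sketch of the write-versus-accumulate dispatch the fixed kernels follow, assuming MXNet's OpReqType values (kNullOp, kWriteTo, kWriteInplace, kAddTo); the helper name is purely illustrative.

#include <cassert>
#include <vector>

// Stand-in for MXNet's OpReqType (assumed to match the real enum's meaning).
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Illustrative helper: store a gradient value the way the fixed kernels do.
// kWriteTo/kWriteInplace overwrite, kAddTo accumulates, kNullOp stores nothing.
inline void AssignOrAdd(float *dst, OpReqType req, float val) {
  switch (req) {
    case kNullOp: break;
    case kWriteTo:
    case kWriteInplace: *dst = val; break;
    case kAddTo: *dst += val; break;
  }
}

int main() {
  std::vector<float> grad(4, 1.0f);  // pretend this already holds an accumulated gradient
  for (float &g : grad) AssignOrAdd(&g, kAddTo, 0.5f);  // a second backward pass adds into it
  assert(grad[0] == 1.5f);
  return 0;
}

Previously the data gradient was either overwritten or skipped when grad_req was not a plain write, so accumulated values were lost; the diffs below fix this separately for the CPU, CUDA, cuDNN, and MKLDNN paths.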

File tree

7 files changed: +378 additions, -83 deletions

src/operator/nn/batch_norm-inl.h

Lines changed: 1 addition & 0 deletions
@@ -259,6 +259,7 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param,
                        const std::vector<TBlob> &outputs) {
   CHECK_EQ(inputs.size(), 8U);
   CHECK_EQ(outputs.size(), 3U);
+
   std::vector<TBlob> out_grad(1);
   std::vector<TBlob> out_data(3);
   std::vector<TBlob> in_data(3);

src/operator/nn/batch_norm.cc

Lines changed: 62 additions & 21 deletions
@@ -85,6 +85,31 @@ static inline void ForEachFast(const BNTensor3<DType1> &in_data,
   }
 }
 
+template<typename DType1, typename DType2, typename DType3, typename OnData>
+static inline void ForEachFast(const BNTensor3<DType1> &in_data,
+                               const BNTensor3<DType2> &in_data2,
+                               const BNTensor3<DType3> &out_data,
+                               const size_t channel,
+                               OnData onData) {
+  const size_t num = in_data.OuterSize();
+  const size_t matrixSize = in_data.InnerSize();
+  const size_t skipLength = in_data.SkipLengthToNextSameChannelData();
+  const size_t startOffset = in_data.StartOffset(channel);
+
+  DType1 *data = in_data.dptr_ + startOffset;
+  DType2 *data2 = in_data2.dptr_ + startOffset;
+  DType3 *odata = out_data.dptr_ + startOffset;
+
+  for (size_t outer = 0; outer < num; ++outer) {
+    for (size_t i = 0; i < matrixSize; ++i) {
+      onData(data++, data2++, odata++);
+    }
+    data += skipLength;
+    data2 += skipLength;
+    odata += skipLength;
+  }
+}
+
 }  // namespace batchnorm
 
 /*! \brief Forward CPU */
@@ -263,7 +288,7 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
                   dotp += (*thisInputData - mean) * (*gradOut_data);
                 });
 
-    if (!gradIn.IsEmpty() && IsBNWriting(req[batchnorm::kData])) {  // if there's a grad input
+    if (!gradIn.IsEmpty() && req[batchnorm::kData] != kNullOp) {  // if there's a grad input
       if (is_train_and_not_global_stats) {
         // when in training mode
         // Q(X) = X - E[x] ; i.e. input centered to zero mean
@@ -272,44 +297,60 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
 
         // projection of gradOutput on to output scaled by std
         const AccReal k = dotp * invstd * invstd / itemCount;
-        ForEachFast(inputData, gradIn, static_cast<size_t>(channel),
-                    [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) {
-                      *gradIn_data = (*inputDataPtr - mean) * k;
-                    });
-
         const AccReal iw = invstd * w;
         const AccReal gradMean = sumGradOut / itemCount;
-        ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
-                    [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) {
-                      *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw;
-                    });
+        if (req[batchnorm::kData] != kAddTo) {
+          ForEachFast(inputData, gradIn, static_cast<size_t>(channel),
+                      [&mean, &k](const DType *inputDataPtr, DType *gradIn_data) {
+                        *gradIn_data = (*inputDataPtr - mean) * k;
+                      });
+
+          ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
+                      [iw, gradMean](const DType *gradOut_data, DType *gradIn_data) {
+                        *gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw;
+                      });
+        } else {
+          ForEachFast(inputData, gradOut, gradIn, static_cast<size_t>(channel),
+                      [&mean, &k, iw, gradMean](const DType *inputDataPtr,
+                                                const DType *gradOut_data,
+                                                DType *gradIn_data) {
+                        DType normal_val = (*inputDataPtr - mean) * k;
+                        *gradIn_data += (*gradOut_data - gradMean -
+                                         normal_val) * iw;
+                      });
+        }
       } else {
         // when in evaluation mode
         // Q(X) = X - running_mean ; i.e. input centered to zero mean
         // Y = Q(X) / running_std ; i.e. BN output before weight and bias
         // dL/dX = w / running_std
         const AccReal iw = invstd * w;
-        ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
-                    [iw](const DType *gradOut_data, DType *gradIn_data) {
-                      *gradIn_data = *gradOut_data * iw;
-                    });
+        if (req[batchnorm::kData] != kAddTo) {
+          ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
+                      [iw](const DType *gradOut_data, DType *gradIn_data) {
+                        *gradIn_data = *gradOut_data * iw;
+                      });
+        } else {
+          ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
+                      [iw](const DType *gradOut_data, DType *gradIn_data) {
+                        *gradIn_data += *gradOut_data * iw;
+                      });
+        }
       }
     }
 
     // May want to make this a param eventually
     const AccReal scale = 1.0f;
 
-    if (IsBNWriting(req[batchnorm::kGamma])) {
-      if (!param_.fix_gamma) {
-        gradWeightData[channel] = scale * dotp * invstd;
-      } else {
+    if (!param_.fix_gamma) {
+      KERNEL_ASSIGN(gradWeightData[channel], req[batchnorm::kGamma], scale * dotp * invstd);
+    } else {
+      if (IsBNWriting(req[batchnorm::kGamma])) {
         gradWeightData[channel] = AccReal(0);
       }
     }
 
-    if (IsBNWriting(req[batchnorm::kBeta])) {
-      gradBiasData[channel] = scale * sumGradOut;
-    }
+    KERNEL_ASSIGN(gradBiasData[channel], req[batchnorm::kBeta], scale * sumGradOut);
   }
 }

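The new three-tensor ForEachFast overload above exists because the original kWriteTo path uses gradIn as scratch space: it first writes the projection (x - mean) * k into gradIn and then overwrites it in a second pass, which would clobber a gradient that is supposed to be accumulated. A rough single-pass equivalent over flat arrays (names and tensor-layout handling are simplified for illustration):

#include <cstddef>

// Accumulate the batch-norm data gradient without using grad_in as scratch.
// mean, k, grad_mean and iw stand for the per-channel statistics computed
// earlier in BatchNormBackwardImpl.
void AccumulateDataGrad(const float *x, const float *grad_out, float *grad_in,
                        std::size_t n, float mean, float k,
                        float grad_mean, float iw) {
  for (std::size_t i = 0; i < n; ++i) {
    const float proj = (x[i] - mean) * k;                 // projection onto the normalized output
    grad_in[i] += (grad_out[i] - grad_mean - proj) * iw;  // add, never overwrite
  }
}

int main() {
  const float x[2] = {1.0f, 3.0f}, gout[2] = {0.5f, -0.5f};
  float gin[2] = {1.0f, 1.0f};  // gradient already accumulated by an earlier pass
  AccumulateDataGrad(x, gout, gin, 2, 2.0f, 0.1f, 0.0f, 1.0f);
  return 0;
}
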
src/operator/nn/batch_norm.cu

Lines changed: 52 additions & 16 deletions
@@ -34,6 +34,9 @@
 #define FIX_GAMMA_FLAG 8
 #define IS_TRAINING_FLAG 16
 #define USE_GLOBAL_STATS_FLAG 32
+#define ADDTO_DATA_FLAG (1 << 6)
+#define ADDTO_GAMMA_FLAG (1 << 7)
+#define ADDTO_BETA_FLAG (1 << 8)
 
 #if MXNET_USE_CUDNN == 1
 #include "./cudnn/cudnn_batch_norm-inl.h"
@@ -361,33 +364,60 @@ static __global__ void BatchNormalizationBackwardKernel(
         * momentum + localVariance * (AccReal(1) - momentum);
   }
 
-  if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) {
-    for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
-      for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
-        const DType gradOut = gradOutput.get_ref(batch, plane, x);
-        if (is_train_and_not_global_stats) {
-          const DType inp = input.get_ref(batch, plane, x);
-          const AccReal proj = (inp - mean) * projScale;
-          gradInput.get_ref(batch, plane, x) =
-            ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
-        } else {
-          gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
-            gradOut * gradScale);
+  if (gradInput.Size() > 0 && (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) != 0) {
+    const bool grad_write = flags & WRITE_DATA_FLAG;
+    if (grad_write) {
+      for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
+        for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
+          const DType gradOut = gradOutput.get_ref(batch, plane, x);
+          if (is_train_and_not_global_stats) {
+            const DType inp = input.get_ref(batch, plane, x);
+            const AccReal proj = (inp - mean) * projScale;
+            gradInput.get_ref(batch, plane, x) =
+              ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
+          } else {
+            gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
+              gradOut * gradScale);
+          }
+        }
+      }
+    } else {
+      // grad addto
+      for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
+        for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
+          const DType gradOut = gradOutput.get_ref(batch, plane, x);
+          if (is_train_and_not_global_stats) {
+            const DType inp = input.get_ref(batch, plane, x);
+            const AccReal proj = (inp - mean) * projScale;
+            gradInput.get_ref(batch, plane, x) +=
+              ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
+          } else {
+            gradInput.get_ref(batch, plane, x) += ScalarConvert<AccReal, DType>::to(
+              gradOut * gradScale);
+          }
         }
       }
     }
   }
 
-  if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) {
+  if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 &&
+      (flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) {
     if ((flags & FIX_GAMMA_FLAG) == 0) {
-      tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
+      if (flags & WRITE_GAMMA_FLAG)
+        tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
+      else
+        tensors.gradWeight[plane] += ScalarConvert<AccReal, DType>::to(dotP * invstd);
     } else {
       tensors.gradWeight[plane] = DType(0);
     }
   }
 
-  if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) {
-    tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
+  if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 &&
+      (flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) {
+    if (flags & WRITE_BETA_FLAG)
+      tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
+    else
+      tensors.gradBias[plane] += ScalarConvert<AccReal, DType>::to(gradOutputSum);
   }
 }
@@ -579,12 +609,18 @@ static inline uint32_t SetupFlags(const OpContext &ctx,
   flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
   if (IsBNWriting(req[batchnorm::kData])) {
     flags |= WRITE_DATA_FLAG;
+  } else if (req[batchnorm::kData] == kAddTo) {
+    flags |= ADDTO_DATA_FLAG;
   }
   if (IsBNWriting(req[batchnorm::kGamma])) {
     flags |= WRITE_GAMMA_FLAG;
+  } else if (req[batchnorm::kGamma] == kAddTo) {
+    flags |= ADDTO_GAMMA_FLAG;
   }
   if (IsBNWriting(req[batchnorm::kBeta])) {
     flags |= WRITE_BETA_FLAG;
+  } else if (req[batchnorm::kBeta] == kAddTo) {
+    flags |= ADDTO_BETA_FLAG;
   }
   return flags;
 }

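Because the CUDA kernel receives a single flags word rather than the req vector, SetupFlags encodes each output's request as a bit: a WRITE_* bit for overwrite, the new ADDTO_* bit for accumulation, and neither for kNullOp. A small sketch of that encode/decode round trip; the WRITE_DATA_FLAG value of 1 is an assumption based on the existing 8/16/32 sequence, since the diff does not show the earlier defines.

#include <cstdint>
#include <cstdio>

// Flag values mirroring batch_norm.cu; WRITE_DATA_FLAG = 1 is assumed.
constexpr uint32_t WRITE_DATA_FLAG = 1;
constexpr uint32_t ADDTO_DATA_FLAG = 1 << 6;

enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Mirror of the SetupFlags logic for one output: write and add-to are
// mutually exclusive bits, and kNullOp sets neither.
uint32_t EncodeDataReq(OpReqType req) {
  uint32_t flags = 0;
  if (req == kWriteTo || req == kWriteInplace) flags |= WRITE_DATA_FLAG;
  else if (req == kAddTo)                      flags |= ADDTO_DATA_FLAG;
  return flags;
}

int main() {
  const uint32_t flags = EncodeDataReq(kAddTo);
  // Inside the kernel: enter the gradient-input block for either request,
  // then pick overwrite vs. accumulate from the WRITE bit.
  if (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) {
    std::printf(flags & WRITE_DATA_FLAG ? "overwrite\n" : "accumulate\n");
  }
  return 0;
}
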
src/operator/nn/cudnn/cudnn_batch_norm-inl.h

Lines changed: 13 additions & 2 deletions
@@ -222,13 +222,24 @@ class CuDNNBatchNormOp {
 
     if (param_.fix_gamma) gamma = 1.f;
 
+    bool grad_add_gamma_beta = (req[cudnnbatchnorm::kGamma] == kAddTo) ||
+                               (req[cudnnbatchnorm::kBeta] == kAddTo);
+    if (grad_add_gamma_beta) {
+      if (IsBNWriting(req[cudnnbatchnorm::kGamma])) {
+        dgamma = 0.f;
+      }
+      if (IsBNWriting(req[cudnnbatchnorm::kBeta])) {
+        dbeta = 0.f;
+      }
+    }
+
     CUDNN_CALL(cudnnBatchNormalizationBackward(
       s->dnn_handle_,
       mode,
       &a,
-      &b,
+      req[cudnnbatchnorm::kData] == kAddTo ? &b_add : &b,
       &a,
-      req[cudnnbatchnorm::kGamma] == kWriteTo ? &b: &b_add,
+      grad_add_gamma_beta ? &b_add : &b,  // gamma and beta
       io_desc_,
       x.dptr_,
       io_desc_,

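In the cuDNN call, a and b are the alpha/beta scaling factors that cudnnBatchNormalizationBackward applies as dst = alpha * result + beta * dst, so passing b_add (presumably beta = 1) instead of b (presumably beta = 0) is how this path accumulates into an existing gradient. Since cuDNN applies a single beta to both parameter gradients, the zeroing of dgamma/dbeta above converts accumulation into a plain write for whichever of the two is requested with kWriteTo. A toy restatement of the blending rule:

#include <cassert>

// cuDNN-style output blending: dst = alpha * result + beta * dst.
// beta = 0 overwrites the destination (kWriteTo); beta = 1 adds the new
// gradient to what is already there (kAddTo).
inline void BlendInto(float *dst, float result, float alpha, float beta) {
  *dst = alpha * result + beta * *dst;
}

int main() {
  float dgamma = 2.0f;                    // previously accumulated gradient
  BlendInto(&dgamma, 3.0f, 1.0f, 1.0f);   // b_add-style call: accumulate
  assert(dgamma == 5.0f);
  BlendInto(&dgamma, 3.0f, 1.0f, 0.0f);   // b-style call: overwrite
  assert(dgamma == 3.0f);
  return 0;
}
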
src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h

Lines changed: 30 additions & 11 deletions
@@ -347,7 +347,8 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   else if (diff.IsDefaultData())
     diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc());
   auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
-  auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_desc());
+  auto gradi_mem = CreateMKLDNNMem(const_cast<NDArray &>(gradIn),
+                                   bwd.pd.diff_src_desc(), req[batchnorm::kData]);
 
   if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
     const NDArray &gamma = in_data[batchnorm::kGamma];
@@ -368,7 +369,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
     }
     mkldnn_args_map_t net_args;
     net_args[MKLDNN_ARG_SRC] = *data_mem;
-    net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem;
+    net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem.second;
     net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight();
     net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
     net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem;
@@ -401,28 +402,46 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
       }
       net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData());
       net_args[MKLDNN_ARG_VARIANCE] = var_mem;
-      MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
-      MKLDNNStream::Get()->Submit();
     } else {
       net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData());
       net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData());
-      MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
-      MKLDNNStream::Get()->Submit();
     }
+    MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
+    CommitOutput(gradIn, gradi_mem);
+    MKLDNNStream::Get()->Submit();
 
     // copy data from gradw_mem to in_grad[1] and in_grad[2]
     DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
-    DType *w_grad_1 = in_grad[1].data().dptr<DType>();
-    DType *w_grad_2 = in_grad[2].data().dptr<DType>();
+    DType *w_grad_1 = in_grad[batchnorm::kGamma].data().dptr<DType>();
+    DType *w_grad_2 = in_grad[batchnorm::kBeta].data().dptr<DType>();
 
+    // the gradient of gamma
     if (!param.fix_gamma) {
-      memcpy(w_grad_1, gw_buf, copy_size);
-      memcpy(w_grad_2, &gw_buf[channels_], copy_size);
+      if (req[batchnorm::kGamma] != kNullOp) {
+        if (req[batchnorm::kGamma] != kAddTo) {
+          memcpy(w_grad_1, gw_buf, copy_size);
+        } else {
+          for (int i = 0; i < channels_; i++) {
+            w_grad_1[i] += gw_buf[i];
+          }
+        }
+      }
     } else {
       for (int i = 0; i < channels_; i++) {
        (in_grad[1].data().dptr<DType>())[i] = 0.0f;
      }
-      memcpy(w_grad_2, &gw_buf[channels_], copy_size);
+    }
+
+    // the gradient of beta
+    if (req[batchnorm::kBeta] != kNullOp) {
+      if (req[batchnorm::kBeta] != kAddTo) {
+        memcpy(w_grad_2, &gw_buf[channels_], copy_size);
+      } else {
+        DType *grad_beta = &gw_buf[channels_];
+        for (int i = 0; i < channels_; i++) {
+          w_grad_2[i] += grad_beta[i];
+        }
+      }
     }
   } else {
     LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";

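For the data gradient, the MKLDNN path switches from writing directly into gradIn to the CreateMKLDNNMem/CommitOutput pairing visible above: the usual role of this pairing is to give the primitive a temporary buffer when the request is kAddTo and then add that buffer into the real output after the primitive runs, while plain writes can target the output directly. A rough model of that deferred-commit idea (the real helpers work on NDArray/mkldnn::memory; everything below is simplified and the names are illustrative):

#include <cstddef>
#include <vector>

enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Simplified stand-ins for CreateMKLDNNMem / CommitOutput.
struct PendingOutput {
  OpReqType req;
  std::vector<float> *target;   // the real output buffer
  std::vector<float> scratch;   // primitive writes here when req == kAddTo
};

PendingOutput PrepareOutput(std::vector<float> *target, OpReqType req) {
  PendingOutput out{req, target, {}};
  if (req == kAddTo) out.scratch.assign(target->size(), 0.0f);
  return out;
}

// Buffer the primitive should write into.
std::vector<float> *WriteBuffer(PendingOutput *out) {
  return out->req == kAddTo ? &out->scratch : out->target;
}

// After the primitive has run: fold the scratch buffer into the real output.
void Commit(PendingOutput *out) {
  if (out->req != kAddTo) return;  // direct writes need no commit step
  for (std::size_t i = 0; i < out->target->size(); ++i)
    (*out->target)[i] += out->scratch[i];
}

int main() {
  std::vector<float> grad(3, 1.0f);
  PendingOutput out = PrepareOutput(&grad, kAddTo);
  (*WriteBuffer(&out))[0] = 2.0f;  // pretend the primitive produced this value
  Commit(&out);
  return grad[0] == 3.0f ? 0 : 1;
}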