This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit bd67723

ptrendx authored and DickJC123 committed
Fix operators lying about their number of inputs (#17049)
* Add a check for number of inputs
* Fix num inputs for backward_Deconvolution
* Fix number of inputs to backward ROIAlign
* Fix number of inputs to backward_SoftmaxOutput
* Fix more operators lying about their number of inputs
* Fix input number of backward NMS
* Fixes
* Fix dropout, RNN and upsampling backward number of inputs
* Fix LeakyRelu number of inputs
* Actually fix LeakyRelu
* Fix pooling and concat
* Fix Concat (attempt 2)
* Fix from review
* Incorporate Dick's changes
* Add guard to MakeNonlossGradNode
* Fix
* Fix backward of SoftmaxActivation
* Fix backward of np_prod and norm
1 parent be9e17e commit bd67723

21 files changed: +105 -30 lines
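
The changes share one pattern: each _backward_* operator's set_num_inputs must now report exactly the number of inputs its gradient node actually receives, because a new check in operator_common.h (below) compares the declared count against the wired-up count and aborts on mismatch. A minimal sketch of the registration shape, with _backward_MyOp, MyParam, and needs_data as hypothetical stand-ins rather than code from this commit:

    NNVM_REGISTER_OP(_backward_MyOp)
    // Declare the true input count: one output gradient per forward output,
    // plus whatever saved forward inputs/outputs the backward kernel reads.
    .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
      const MyParam& param = nnvm::get<MyParam>(attrs.parsed);
      return param.needs_data ? 2 : 1;  // ograd, plus the saved forward input
    })
    .set_num_outputs(1)  // one gradient per forward input
    .set_attr<nnvm::TIsBackward>("TIsBackward", true);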

src/operator/contrib/bilinear_resize-inl.h

Lines changed: 0 additions & 9 deletions
@@ -328,15 +328,6 @@ inline uint16_t BilinearSampleOpNumInputs(const NodeAttrs& attrs) {
   }
 }
 
-inline uint16_t BilinearSampleOpNumBackwardInputs(const NodeAttrs& attrs) {
-  auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
-  if (param.mode == bilinear_resize::like) {
-    return 3;
-  } else {
-    return 1;
-  }
-}
-
 inline uint16_t BilinearSampleOpNumBackwardOutputs(const NodeAttrs& attrs) {
   auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
   if (param.mode == bilinear_resize::like) {

src/operator/contrib/bilinear_resize.cc

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ for more details.
 
 NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D)
 .set_attr_parser(ParamParser<BilinearSampleParam>)
-.set_num_inputs(BilinearSampleOpNumBackwardInputs)
+.set_num_inputs(1)
 .set_num_outputs(BilinearSampleOpNumBackwardOutputs)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute<cpu>", BilinearSampleOpBackward<cpu>);

src/operator/contrib/bounding_box.cc

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ Examples::
 .add_arguments(BoxNMSParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_contrib_box_nms)
-.set_num_inputs(3)
+.set_num_inputs(4)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<BoxNMSParam>)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)

src/operator/contrib/roi_align.cc

Lines changed: 1 addition & 0 deletions
@@ -621,6 +621,7 @@ He, Kaiming, et al. "Mask R-CNN." ICCV, 2017
 
 
 NNVM_REGISTER_OP(_backward_ROIAlign)
+.set_num_inputs(2)
 .set_num_outputs(2)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr_parser(ParamParser<ROIAlignParam>)

src/operator/custom/custom.cc

Lines changed: 1 addition & 1 deletion
@@ -586,7 +586,7 @@ Please check the tutorial here: https://mxnet.incubator.apache.org/api/faq/new_o
 NNVM_REGISTER_OP(_backward_Custom)
 .set_num_inputs([](const NodeAttrs& attrs){
     const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);
-    return params.bwd_idx.size();
+    return params.bwd_idx.size() + params.num_auxs;
   })
 .set_num_outputs([](const NodeAttrs& attrs){
     const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);

src/operator/image/image_random.cc

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ NNVM_REGISTER_OP(_image_normalize)
 
 NNVM_REGISTER_OP(_backward_image_normalize)
 .set_attr_parser(ParamParser<NormalizeParam>)
-.set_num_inputs(1)
+.set_num_inputs(2)
 .set_num_outputs(1)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute<cpu>", NormalizeOpBackward<cpu>);

src/operator/leaky_relu.cc

Lines changed: 13 additions & 0 deletions
@@ -206,6 +206,19 @@ The following modified ReLU Activation functions are supported:
 });
 
 NNVM_REGISTER_OP(_backward_LeakyReLU)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
+  if (param.act_type == leakyrelu::kPReLU) {
+    // forward has 2 inputs and 1 output
+    return 2 + 2 * 1;
+  } else if (param.act_type == leakyrelu::kRReLU) {
+    // forward has 1 input and 2 outputs
+    return 1 + 2 * 2;
+  } else {
+    // forward has 1 input and 1 output
+    return 1 + 2 * 1;
+  }
+})
 .set_num_outputs([](const NodeAttrs& attrs) {
   const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
   return param.act_type == leakyrelu::kPReLU ? 2 : 1;
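
The arithmetic in these branches follows the counting convention spelled out in the pooling change below (backward inputs = forward inputs + 2 * forward outputs, since each forward output contributes both its incoming gradient and its saved value): PReLU has 2 inputs and 1 output, so 2 + 2*1 = 4; RReLU has 1 input and 2 outputs, so 1 + 2*2 = 5; every other activation type gets 1 + 2*1 = 3.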

src/operator/nn/batch_norm.cc

Lines changed: 1 addition & 0 deletions
@@ -593,6 +593,7 @@ then set ``gamma`` to 1 and its gradient to 0.
 });
 
 NNVM_REGISTER_OP(_backward_BatchNorm)
+.set_num_inputs(8)
 .set_num_outputs(3)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)

src/operator/nn/concat.cc

Lines changed: 8 additions & 0 deletions
@@ -396,6 +396,14 @@ CONCAT_FORWARD_ATTRS
 .add_arguments(ConcatParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_Concat)
+.set_num_inputs([](const NodeAttrs& attrs) {
+#if MXNET_USE_MKLDNN == 1
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+  return 1 + params.num_args;
+#else
+  return 1;
+#endif
+})
 .set_num_outputs([](const NodeAttrs& attrs) {
   const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
   return params.num_args;
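
Note the build-flag dependence: with MKLDNN enabled, the backward op consumes the output gradient plus all num_args forward inputs (for example, 1 + 3 = 4 inputs when concatenating three arrays), while the plain build needs only the single output gradient.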

src/operator/nn/convolution.cc

Lines changed: 4 additions & 0 deletions
@@ -510,6 +510,10 @@ There are other options to tune the performance.
 .add_arguments(ConvolutionParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_Convolution)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
+  return params.no_bias ? 3 : 4;
+})
 .set_num_outputs([](const NodeAttrs& attrs) {
   const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed);
   return params.no_bias ? 2 : 3;

src/operator/nn/ctc_loss.cc

Lines changed: 1 addition & 1 deletion
@@ -130,7 +130,7 @@ information on the definition and the algorithm.
 
 NNVM_REGISTER_OP(_backward_ctc_loss)
 .set_attr_parser(ParamParser<CTCLossOpParam>)
-.set_num_inputs(1)
+.set_num_inputs(4)
 .set_num_outputs(CTCLossOpNumInputs)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute<cpu>", CTCLossOpBackward<cpu>);

src/operator/nn/deconvolution.cc

Lines changed: 4 additions & 0 deletions
@@ -445,6 +445,10 @@ NNVM_REGISTER_OP(Deconvolution)
 .add_arguments(DeconvolutionParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_Deconvolution)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const DeconvolutionParam& params = nnvm::get<DeconvolutionParam>(attrs.parsed);
+  return params.no_bias ? 3 : 4;
+})
 .set_num_outputs([](const NodeAttrs& attrs) {
   const DeconvolutionParam& params = nnvm::get<DeconvolutionParam>(attrs.parsed);
   return params.no_bias ? 2 : 3;

src/operator/nn/lrn.cc

Lines changed: 1 addition & 0 deletions
@@ -184,6 +184,7 @@ number of kernels in the layer.
 .add_arguments(LRNParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_LRN)
+.set_num_inputs(3)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<LRNParam>)
 #if MXNET_USE_MKLDNN == 1

src/operator/nn/pooling.cc

Lines changed: 5 additions & 0 deletions
@@ -453,6 +453,11 @@ For each window ``X``, the mathematical expression for Lp pooling is:
 .add_arguments(PoolingParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_Pooling)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
+  // 1 input to fwd op and 2 * outputs from fwd op (fwd outputs and gradient inputs)
+  return 1 + 2 * GetNumOutputs(param);
+})
 .set_num_outputs(1)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<nnvm::FInplaceOption>(

src/operator/nn/softmax_activation.cc

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ Example::
 .add_arguments(SoftmaxActivationParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_SoftmaxActivation)
+.set_num_inputs(2)
 .set_num_outputs(1)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){

src/operator/nn/upsampling.cc

Lines changed: 8 additions & 0 deletions
@@ -211,6 +211,14 @@ Example::
 });
 
 NNVM_REGISTER_OP(_backward_UpSampling)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const UpSamplingParam& param_ = nnvm::get<UpSamplingParam>(attrs.parsed);
+  if (param_.sample_type != up_enum::kNearest) {
+    return 3;
+  } else {
+    return 1;
+  }
+})
 .set_num_outputs([](const NodeAttrs& attrs) {
   const UpSamplingParam& params = nnvm::get<UpSamplingParam>(attrs.parsed);
   return params.sample_type == up_enum::kNearest ? params.num_args : 2;

src/operator/numpy/np_broadcast_reduce_op_value.cc

Lines changed: 1 addition & 1 deletion
@@ -238,7 +238,7 @@ NNVM_REGISTER_OP(_np_prod)
 .set_attr<nnvm::FGradient>("FGradient", ReduceGrad{"_backward_np_prod"});
 
 NNVM_REGISTER_OP(_backward_np_prod)
-.set_num_inputs(1)
+.set_num_inputs(3)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<NumpyReduceAxesParam>)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)

src/operator/operator_common.h

Lines changed: 37 additions & 15 deletions
@@ -359,12 +359,36 @@ inline bool dispatch_fallback(StorageTypeVector* stypes, DispatchMode* dispatch)
   return true;
 }
 
+inline std::vector<nnvm::NodeEntry> CreateNodeEntries(
+    nnvm::NodePtr pNode,
+    const std::vector<nnvm::NodeEntry>* pOgrads = nullptr,
+    const std::vector<nnvm::NodeEntry>* pInputs = nullptr) {
+  if (pOgrads)
+    pNode->inputs.insert(pNode->inputs.end(), pOgrads->begin(), pOgrads->end());
+
+  if (pInputs)
+    pNode->inputs.insert(pNode->inputs.end(), pInputs->begin(), pInputs->end());
+
+  if (!pNode->is_variable()) {
+    CHECK_EQ(pNode->num_inputs(), pNode->inputs.size())
+      << "Number of inputs to operator " << pNode->op()->name << " (" << pNode->num_inputs()
+      << ") does not match the actual number of inputs provided to operator "
+      << pNode->attrs.name << " (" << pNode->inputs.size() << ").";
+  }
+
+  std::vector<nnvm::NodeEntry> ret;
+  for (uint32_t i = 0; i < pNode->num_outputs(); ++i)
+    ret.emplace_back(nnvm::NodeEntry{pNode, i, 0});
+
+  return ret;
+}
+
 // make a new node with operator op_name. Inputs are not filled.
 inline nnvm::NodePtr MakeNode(
     const char* op_name, const std::string& name,
-    std::vector<nnvm::NodeEntry> const * inputs,
-    std::unordered_map<std::string, std::string> const * dict,
-    nnvm::NodePtr const * fwd_node) {
+    std::vector<nnvm::NodeEntry> const * inputs = nullptr,
+    std::unordered_map<std::string, std::string> const * dict = nullptr,
+    nnvm::NodePtr const * fwd_node = nullptr) {
   auto p = nnvm::Node::Create();
   p->attrs.op = nnvm::Op::Get(op_name);
   p->attrs.name = name;
@@ -376,6 +400,12 @@ inline nnvm::NodePtr MakeNode(
   if (p->op()->attr_parser != nullptr) {
     p->op()->attr_parser(&(p->attrs));
   }
+  if (inputs != nullptr) {
+    CHECK_EQ(p->num_inputs(), p->inputs.size())
+      << "Number of inputs to operator " << op_name << " (" << p->num_inputs()
+      << ") does not match the actual number of inputs provided to operator "
+      << name << " (" << p->inputs.size() << ").";
+  }
   return p;
 }
 
@@ -395,11 +425,8 @@ inline std::vector<nnvm::NodeEntry> MakeGradNode(
     const std::unordered_map<std::string, std::string>& dict) {
   auto p = MakeNode(op_name, n->attrs.name + "_backward",
                     &inputs, &dict, &n);
-  std::vector<nnvm::NodeEntry> ret;
-  for (uint32_t i = 0; i < p->num_outputs(); ++i) {
-    ret.emplace_back(p, i, 0);
-  }
-  return ret;
+
+  return CreateNodeEntries(p);
 }
 
 // quick helper to make gradient nodes that simply pass back zero. could be used in output ops.
@@ -446,13 +473,8 @@ inline std::vector<nnvm::NodeEntry> MakeNonlossGradNode(
     return MakeZeroGradNodes(n, ograds);
   auto p = MakeNode(op_name, n->attrs.name + "_backward",
                     nullptr, &dict, &n);
-  p->inputs.insert(p->inputs.end(), ograds.begin(), ograds.end());
-  p->inputs.insert(p->inputs.end(), inputs.begin(), inputs.end());
-  std::vector<nnvm::NodeEntry> ret;
-  for (uint32_t i = 0; i < p->num_outputs(); ++i) {
-    ret.emplace_back(p, i, 0);
-  }
-  return ret;
+
+  return CreateNodeEntries(p, &ograds, &inputs);
 }
 
 /*! \brief Parse keyword arguments as PType arguments and save to parsed */
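
This helper is the enforcement point for the whole commit: MakeGradNode and MakeNonlossGradNode now route through CreateNodeEntries, whose CHECK_EQ fires whenever a backward op's declared num_inputs() disagrees with the node entries actually appended. A minimal sketch of how a gradient registration reaches the check, assuming a hypothetical operator _backward_MyOp (not part of this commit):

    // FGradient for a forward op whose backward kernel needs the output
    // gradient plus the first forward input. MakeNonlossGradNode appends
    // ograds and inputs to the new node; CreateNodeEntries then verifies
    // that 1 + 1 = 2 matches what _backward_MyOp declared via set_num_inputs.
    .set_attr<nnvm::FGradient>("FGradient",
      [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
        return MakeNonlossGradNode("_backward_MyOp", n, ograds,
                                   {n->inputs[0]}, n->attrs.dict);
      });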

src/operator/rnn.cc

Lines changed: 14 additions & 0 deletions
@@ -406,6 +406,20 @@ The definition of GRU here is slightly different from paper but compatible with
 .add_arguments(RNNParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_RNN)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  int ret = 5;
+  if (params.state_outputs) {
+    ret += 2;
+  }
+  if (params.mode == rnn_enum::kLstm) {
+    ++ret;
+    if (params.state_outputs) {
+      ret += 2;
+    }
+  }
+  return ret;
+})
 .set_num_outputs([](const NodeAttrs& attrs) {
   const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
   return GetNumInputArguments(params);
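
Worked through: the base count of 5 grows by 2 when state_outputs is set, by 1 more in LSTM mode, and by another 2 for an LSTM that also emits state outputs. So a plain RNN backward declares 5 inputs, an LSTM without state outputs declares 6, and an LSTM with state outputs declares 5 + 2 + 1 + 2 = 10.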

src/operator/softmax_output.cc

Lines changed: 1 addition & 0 deletions
@@ -258,6 +258,7 @@ NNVM_REGISTER_OP(SoftmaxOutput)
 NNVM_REGISTER_OP(SoftmaxOutput).add_alias("Softmax");
 
 NNVM_REGISTER_OP(_backward_SoftmaxOutput)
+.set_num_inputs(2)
 .set_num_outputs(2)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){

src/operator/tensor/broadcast_reduce_norm_value.cc

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ Examples::
 .add_arguments(NormParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_norm)
+.set_num_inputs(3)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<NormParam>)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
