
Commit fbb1a92

xinyu-intel authored and drivanov committed
Integrate MKL-DNN leakyrelu (apache#16075)
* add mkldnn leakyrelu support
* improve mkldnn act param
* register gpu path
* remove old code
* trigger
* fix lint and improve backward function
1 parent a98a988 · commit fbb1a92

File tree: 12 files changed, +400 / -216 lines


src/operator/leaky_relu-inl.h

Lines changed: 38 additions & 154 deletions
@@ -332,166 +332,50 @@ class LeakyReLUOp : public Operator {
 };  // class LeakyReLUOp
 
 template<typename xpu>
-Operator* CreateOp(LeakyReLUParam type, int dtype);
+void LeakyReLUCompute(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx, const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  const LeakyReLUParam &param = nnvm::get<LeakyReLUParam>(attrs.parsed);
+  const std::vector<TBlob> no_use_but_adapt_origin_api;
+  size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1;
+  CHECK_EQ(inputs.size(), expected);
 
-#if DMLC_USE_CXX11
-class LeakyReLUProp : public OperatorProperty {
- public:
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(mxnet::ShapeVector *in_shape,
-                  mxnet::ShapeVector *out_shape,
-                  mxnet::ShapeVector *aux_shape) const override {
-    using namespace mshadow;
-    if (param_.act_type == leakyrelu::kPReLU) {
-      CHECK_EQ(in_shape->size(), 2U) << "Input:[data, gamma]";
-    } else {
-      CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
-    }
-    const mxnet::TShape &dshape = in_shape->at(leakyrelu::kData);
-    if (!mxnet::ndim_is_known(dshape)) return false;
-    if (param_.act_type == leakyrelu::kPReLU) {
-      const mxnet::TShape &gshape = in_shape->at(leakyrelu::kGamma);
-      if (!mxnet::ndim_is_known(gshape)) {
-        in_shape->at(leakyrelu::kGamma) = mxnet::TShape(Shape1(dshape[1]));
-      }
-      if (dshape == gshape) {
-        SHAPE_ASSIGN_CHECK(*out_shape, 0, dshape);
-      }
-    }
-    out_shape->clear();
-    out_shape->push_back(dshape);
-    if (param_.act_type == leakyrelu::kRReLU) {
-      out_shape->push_back(dshape);
-    }
-    return true;
-  }
-
-  bool InferType(std::vector<int> *in_type,
-                 std::vector<int> *out_type,
-                 std::vector<int> *aux_type) const override {
-    int dtype = -1;
-    for (const int& type : *in_type) {
-      type_assign(&dtype, type);
-    }
-    for (const int& type : *out_type) {
-      type_assign(&dtype, type);
-    }
-
-    for (size_t i = 0; i < in_type->size(); ++i) {
-      TYPE_ASSIGN_CHECK(*in_type, i, dtype);
-    }
-    for (size_t i = 0; i < out_type->size(); ++i) {
-      TYPE_ASSIGN_CHECK(*out_type, i, dtype);
-    }
-    return dtype != -1;
-  }
-
-  OperatorProperty* Copy() const override {
-    auto ptr = new LeakyReLUProp();
-    ptr->param_ = param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override {
-    return "LeakyReLU";
-  }
-
-  // decalre dependency and inplace optimization options
-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
-    if (param_.act_type == leakyrelu::kPReLU) {
-      return {out_grad[leakyrelu::kOut],
-              out_data[leakyrelu::kOut],
-              in_data[leakyrelu::kData],
-              in_data[leakyrelu::kGamma]};
-    } else if (param_.act_type == leakyrelu::kRReLU) {
-      return {out_grad[leakyrelu::kOut], out_data[leakyrelu::kMask], out_data[leakyrelu::kOut]};
-    } else {
-      return {out_grad[leakyrelu::kOut], out_data[leakyrelu::kData]};
-    }
-  }
+  MSHADOW_REAL_TYPE_SWITCH(inputs[leakyrelu::kData].type_flag_, DType, {
+    LeakyReLUOp<xpu, DType> op(param);
+    op.Forward(ctx, inputs, req, outputs, no_use_but_adapt_origin_api);
+  });
+}
 
-  std::vector<std::pair<int, void*> > BackwardInplaceOption(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data,
-    const std::vector<void*> &in_grad) const override {
-    return {{out_grad[leakyrelu::kOut], in_grad[leakyrelu::kData]}};
-  }
-
-  std::vector<std::pair<int, void*> > ForwardInplaceOption(
-    const std::vector<int> &in_data,
-    const std::vector<void*> &out_data) const override {
-    if (param_.act_type == leakyrelu::kPReLU) {
-      return {};
-    } else {
-      return {{in_data[leakyrelu::kData], out_data[leakyrelu::kOut]}};
-    }
-  }
-
-  std::vector<std::string> ListArguments() const override {
-    if (param_.act_type == leakyrelu::kPReLU) {
-      return {"data", "gamma"};
-    } else {
-      return {"data"};
-    }
-  }
-
-  std::vector<std::string> ListOutputs() const override {
-    if (param_.act_type == leakyrelu::kRReLU) {
-      return {"output", "mask"};
-    } else {
-      return {"output"};
-    }
-  }
-
-  int NumOutputs() const override {
-    if (param_.act_type == leakyrelu::kRReLU) {
-      return 2;
-    } else {
-      return 1;
-    }
-  }
-
-  int NumVisibleOutputs() const override {
-    return 1;
-  }
-
-  std::vector<ResourceRequest> ForwardResource(
-      const mxnet::ShapeVector &in_shape) const override {
-    if (param_.act_type == leakyrelu::kRReLU) {
-      return {ResourceRequest::kRandom};
-    } else {
-      return std::vector<ResourceRequest>();
-    }
-  }
+template<typename xpu>
+void LeakyReLUGradCompute(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
+  const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
+  const std::vector<TBlob> no_use_but_adapt_origin_api;
+  // inputs: out_grad, input_data, input_gamma, output, output_mask
+  size_t expected_in = param.act_type == leakyrelu::kPReLU ? 2 : 1;
+  size_t expected_out = param.act_type == leakyrelu::kRReLU ? 2 : 1;
 
-  std::vector<ResourceRequest> BackwardResource(
-      const mxnet::ShapeVector &in_shape) const override {
-    return {ResourceRequest::kTempSpace};
-  }
+  CHECK_GE(inputs.size(), 1 + expected_in + expected_out);
+  std::vector<TBlob> out_grad{inputs[0]};
+  std::vector<TBlob> in_data(inputs.begin() + 1,
+                             inputs.begin() + 1 + expected_in);
+  std::vector<TBlob> out_data(inputs.begin() + 1 + expected_in,
+                              inputs.begin() + 1 + expected_in + expected_out);
 
-  Operator* CreateOperator(Context ctx) const override {
-    LOG(FATAL) << "Not Implemented.";
-    return NULL;
-  }
+  CHECK_EQ(req.size(), outputs.size());
+  int dtype = inputs[0].type_flag_;
+  const std::vector<TBlob> &in_grad = outputs;
 
-  Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
-                             std::vector<int> *in_type) const override;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    LeakyReLUOp<xpu, DType> op(param);
+    op.Backward(ctx, out_grad, in_data, out_data, req, in_grad, no_use_but_adapt_origin_api);
+  });
+}
 
- private:
-  LeakyReLUParam param_;
-};
-#endif  // DMLC_USE_CXX11
 }  // namespace op
 }  // namespace mxnet
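The new LeakyReLUCompute / LeakyReLUGradCompute functions replace the old CreateOp / LeakyReLUProp factory path and are meant to be hooked into the NNVM operator registry by the other files touched in this commit (e.g. leaky_relu.cc and leaky_relu.cu), which are not shown here. Below is a rough sketch of what such a registration typically looks like in MXNet; the op names, attribute lists, and file layout are assumptions based on MXNet's usual conventions, not the exact code from this commit.

// Sketch only: hypothetical excerpt of the CPU-side registration that would
// consume the functions defined in leaky_relu-inl.h.
#include "./leaky_relu-inl.h"

namespace mxnet {
namespace op {

// Forward pass dispatched through the generic FCompute<cpu> attribute.
NNVM_REGISTER_OP(LeakyReLU)
.set_attr<FCompute>("FCompute<cpu>", LeakyReLUCompute<cpu>);

// Backward pass wired to the corresponding backward op.
NNVM_REGISTER_OP(_backward_LeakyReLU)
.set_attr<FCompute>("FCompute<cpu>", LeakyReLUGradCompute<cpu>);

// The "register gpu path" item in the commit message would correspond to the
// analogous registrations in leaky_relu.cu, binding FCompute<gpu> to the same
// templated functions instantiated for gpu.

}  // namespace op
}  // namespace mxnet

With this registration style, all gradient dependencies (out_grad, input data, gamma, output, mask) arrive packed in a single inputs vector, which is why LeakyReLUGradCompute slices it into out_grad / in_data / out_data before calling LeakyReLUOp::Backward, instead of declaring them through the removed DeclareBackwardDependency.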
