This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Integrate MKL-DNN leakyrelu #16075

Merged · 9 commits · Sep 24, 2019
src/operator/leaky_relu-inl.h — 192 changes: 38 additions & 154 deletions
@@ -332,166 +332,50 @@ class LeakyReLUOp : public Operator {
}; // class LeakyReLUOp

template<typename xpu>
-Operator* CreateOp(LeakyReLUParam type, int dtype);
+void LeakyReLUCompute(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx, const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  const LeakyReLUParam &param = nnvm::get<LeakyReLUParam>(attrs.parsed);
+  const std::vector<TBlob> no_use_but_adapt_origin_api;
+  size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1;
+  CHECK_EQ(inputs.size(), expected);

-#if DMLC_USE_CXX11
-class LeakyReLUProp : public OperatorProperty {
- public:
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(mxnet::ShapeVector *in_shape,
-                  mxnet::ShapeVector *out_shape,
-                  mxnet::ShapeVector *aux_shape) const override {
-    using namespace mshadow;
-    if (param_.act_type == leakyrelu::kPReLU) {
-      CHECK_EQ(in_shape->size(), 2U) << "Input:[data, gamma]";
-    } else {
-      CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
-    }
-    const mxnet::TShape &dshape = in_shape->at(leakyrelu::kData);
-    if (!mxnet::ndim_is_known(dshape)) return false;
-    if (param_.act_type == leakyrelu::kPReLU) {
-      const mxnet::TShape &gshape = in_shape->at(leakyrelu::kGamma);
-      if (!mxnet::ndim_is_known(gshape)) {
-        in_shape->at(leakyrelu::kGamma) = mxnet::TShape(Shape1(dshape[1]));
-      }
-      if (dshape == gshape) {
-        SHAPE_ASSIGN_CHECK(*out_shape, 0, dshape);
-      }
-    }
-    out_shape->clear();
-    out_shape->push_back(dshape);
-    if (param_.act_type == leakyrelu::kRReLU) {
-      out_shape->push_back(dshape);
-    }
-    return true;
-  }
-
-  bool InferType(std::vector<int> *in_type,
-                 std::vector<int> *out_type,
-                 std::vector<int> *aux_type) const override {
-    int dtype = -1;
-    for (const int& type : *in_type) {
-      type_assign(&dtype, type);
-    }
-    for (const int& type : *out_type) {
-      type_assign(&dtype, type);
-    }
-
-    for (size_t i = 0; i < in_type->size(); ++i) {
-      TYPE_ASSIGN_CHECK(*in_type, i, dtype);
-    }
-    for (size_t i = 0; i < out_type->size(); ++i) {
-      TYPE_ASSIGN_CHECK(*out_type, i, dtype);
-    }
-    return dtype != -1;
-  }
-
-  OperatorProperty* Copy() const override {
-    auto ptr = new LeakyReLUProp();
-    ptr->param_ = param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override {
-    return "LeakyReLU";
-  }
-
-  // declare dependency and inplace optimization options
-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
-    if (param_.act_type == leakyrelu::kPReLU) {
-      return {out_grad[leakyrelu::kOut],
-              out_data[leakyrelu::kOut],
-              in_data[leakyrelu::kData],
-              in_data[leakyrelu::kGamma]};
-    } else if (param_.act_type == leakyrelu::kRReLU) {
-      return {out_grad[leakyrelu::kOut], out_data[leakyrelu::kMask], out_data[leakyrelu::kOut]};
-    } else {
-      return {out_grad[leakyrelu::kOut], out_data[leakyrelu::kData]};
-    }
-  }
+  MSHADOW_REAL_TYPE_SWITCH(inputs[leakyrelu::kData].type_flag_, DType, {
+    LeakyReLUOp<xpu, DType> op(param);
+    op.Forward(ctx, inputs, req, outputs, no_use_but_adapt_origin_api);
+  });
+}
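
Note: a stateless FCompute entry point like the one above is what gets wired into the operator registry. A minimal sketch of that registration, assuming the usual MXNet conventions in src/operator/leaky_relu.cc (the attribute set shown here is illustrative, not quoted from this PR):

// Sketch (assumption): typical NNVM registration of a stateless compute
// function; the real leaky_relu.cc may register additional attributes.
NNVM_REGISTER_OP(LeakyReLU)
.set_attr_parser(ParamParser<LeakyReLUParam>)
.set_attr<FCompute>("FCompute<cpu>", LeakyReLUCompute<cpu>);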

-  std::vector<std::pair<int, void*> > BackwardInplaceOption(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data,
-    const std::vector<void*> &in_grad) const override {
-    return {{out_grad[leakyrelu::kOut], in_grad[leakyrelu::kData]}};
-  }
-
-  std::vector<std::pair<int, void*> > ForwardInplaceOption(
-    const std::vector<int> &in_data,
-    const std::vector<void*> &out_data) const override {
-    if (param_.act_type == leakyrelu::kPReLU) {
-      return {};
-    } else {
-      return {{in_data[leakyrelu::kData], out_data[leakyrelu::kOut]}};
-    }
-  }
-
-  std::vector<std::string> ListArguments() const override {
-    if (param_.act_type == leakyrelu::kPReLU) {
-      return {"data", "gamma"};
-    } else {
-      return {"data"};
-    }
-  }
-
-  std::vector<std::string> ListOutputs() const override {
-    if (param_.act_type == leakyrelu::kRReLU) {
-      return {"output", "mask"};
-    } else {
-      return {"output"};
-    }
-  }
-
-  int NumOutputs() const override {
-    if (param_.act_type == leakyrelu::kRReLU) {
-      return 2;
-    } else {
-      return 1;
-    }
-  }
-
-  int NumVisibleOutputs() const override {
-    return 1;
-  }
-
-  std::vector<ResourceRequest> ForwardResource(
-    const mxnet::ShapeVector &in_shape) const override {
-    if (param_.act_type == leakyrelu::kRReLU) {
-      return {ResourceRequest::kRandom};
-    } else {
-      return std::vector<ResourceRequest>();
-    }
-  }
+template<typename xpu>
+void LeakyReLUGradCompute(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
+  const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
+  const std::vector<TBlob> no_use_but_adapt_origin_api;
+  // inputs: out_grad, input_data, input_gamma, output, output_mask
+  size_t expected_in = param.act_type == leakyrelu::kPReLU ? 2 : 1;
+  size_t expected_out = param.act_type == leakyrelu::kRReLU ? 2 : 1;

-  std::vector<ResourceRequest> BackwardResource(
-    const mxnet::ShapeVector &in_shape) const override {
-    return {ResourceRequest::kTempSpace};
-  }
+  CHECK_GE(inputs.size(), 1 + expected_in + expected_out);
+  std::vector<TBlob> out_grad{inputs[0]};
+  std::vector<TBlob> in_data(inputs.begin() + 1,
+                             inputs.begin() + 1 + expected_in);
+  std::vector<TBlob> out_data(inputs.begin() + 1 + expected_in,
+                              inputs.begin() + 1 + expected_in + expected_out);

-  Operator* CreateOperator(Context ctx) const override {
-    LOG(FATAL) << "Not Implemented.";
-    return NULL;
-  }
+  CHECK_EQ(req.size(), outputs.size());
+  int dtype = inputs[0].type_flag_;
+  const std::vector<TBlob> &in_grad = outputs;

-  Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
-                             std::vector<int> *in_type) const override;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    LeakyReLUOp<xpu, DType> op(param);
+    op.Backward(ctx, out_grad, in_data, out_data, req, in_grad, no_use_but_adapt_origin_api);
+  });
+}
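
Note: the flat inputs vector handed to the gradient entry point concatenates out_grad, the forward inputs, and the forward outputs, which is what the expected_in/expected_out slicing above unpacks. A self-contained sketch of that arithmetic for the PReLU case (plain ints stand in for TBlobs; purely illustrative):

#include <cassert>
#include <vector>

int main() {
  // PReLU: expected_in = 2 (data, gamma), expected_out = 1 (output).
  const size_t expected_in = 2, expected_out = 1;
  // inputs = {out_grad, data, gamma, output}; ints stand in for TBlobs.
  std::vector<int> inputs{10, 20, 21, 30};
  assert(inputs.size() >= 1 + expected_in + expected_out);
  std::vector<int> out_grad{inputs[0]};
  std::vector<int> in_data(inputs.begin() + 1, inputs.begin() + 1 + expected_in);
  std::vector<int> out_data(inputs.begin() + 1 + expected_in,
                            inputs.begin() + 1 + expected_in + expected_out);
  assert(in_data.size() == 2 && out_data.size() == 1);  // {data, gamma}, {output}
  return 0;
}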

- private:
-  LeakyReLUParam param_;
-};
-#endif  // DMLC_USE_CXX11
} // namespace op
} // namespace mxnet
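
For reference, the math these entry points dispatch to is the LeakyReLU family: y = x for x > 0 and y = slope * x otherwise, where the slope is a fixed constant for leaky, a learned gamma for prelu, and a sampled value for rrelu. A scalar sketch of the forward and backward rules (per-channel and sampled slopes collapsed to one scalar for illustration):

// Scalar reference semantics (sketch); slope stands in for the act_type-
// specific value: the fixed slope (leaky), gamma (prelu), or a sampled u (rrelu).
inline float leaky_relu_fwd(float x, float slope) {
  return x > 0.0f ? x : slope * x;
}
// Backward: dL/dx given dL/dy, matching the forward above.
inline float leaky_relu_bwd(float x, float slope, float out_grad) {
  return out_grad * (x > 0.0f ? 1.0f : slope);
}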
