Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[RFC][DTypes] pooling and LeakyReLU #2280

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions src/operator/cudnn_pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
namespace mxnet {
namespace op {

template<typename DType>
class CuDNNPoolingOp : public Operator {
public:
explicit CuDNNPoolingOp(PoolingParam p) {
param_ = p;
init_cudnn_ = false;
// TODO(xxx): fp16
dtype_ = CUDNN_DATA_FLOAT;
dtype_ = mshadow::DataType<DType>::kCudnnFlag;
switch (param_.pool_type) {
case pool_enum::kMaxPooling:
mode_ = CUDNN_POOLING_MAX;
Expand Down Expand Up @@ -51,8 +52,8 @@ class CuDNNPoolingOp : public Operator {
CHECK_EQ(in_data.size(), 1);
CHECK_EQ(out_data.size(), 1);
Stream<gpu> *s = ctx.get_stream<gpu>();
Tensor<gpu, 4> data = in_data[pool_enum::kData].get<gpu, 4, real_t>(s);
Tensor<gpu, 4> out = out_data[pool_enum::kOut].get<gpu, 4, real_t>(s);
Tensor<gpu, 4, DType> data = in_data[pool_enum::kData].get<gpu, 4, DType>(s);
Tensor<gpu, 4, DType> out = out_data[pool_enum::kOut].get<gpu, 4, DType>(s);
CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
if (!init_cudnn_) {
this->Init(s, in_data, out_data);
Expand Down Expand Up @@ -90,10 +91,10 @@ class CuDNNPoolingOp : public Operator {
CHECK_EQ(in_grad.size(), 1);

Stream<gpu> *s = ctx.get_stream<gpu>();
Tensor<gpu, 4> m_out_grad = out_grad[pool_enum::kOut].get<gpu, 4, real_t>(s);
Tensor<gpu, 4> m_in_data = in_data[pool_enum::kData].get<gpu, 4, real_t>(s);
Tensor<gpu, 4> m_out_data = out_data[pool_enum::kOut].get<gpu, 4, real_t>(s);
Tensor<gpu, 4> m_in_grad = in_grad[pool_enum::kData].get<gpu, 4, real_t>(s);
Tensor<gpu, 4, DType> m_out_grad = out_grad[pool_enum::kOut].get<gpu, 4, DType>(s);
Tensor<gpu, 4, DType> m_in_data = in_data[pool_enum::kData].get<gpu, 4, DType>(s);
Tensor<gpu, 4, DType> m_out_data = out_data[pool_enum::kOut].get<gpu, 4, DType>(s);
Tensor<gpu, 4, DType> m_in_grad = in_grad[pool_enum::kData].get<gpu, 4, DType>(s);
CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream<gpu>::OwnHandle);
float alpha = 1.0f;
float beta = 0.0f;
Expand Down Expand Up @@ -148,8 +149,8 @@ class CuDNNPoolingOp : public Operator {
CHECK_EQ(out_data.size(), 1);
if (!init_cudnn_) {
init_cudnn_ = true;
Tensor<gpu, 4> data = in_data[pool_enum::kData].get<gpu, 4, real_t>(s);
Tensor<gpu, 4> out = out_data[pool_enum::kOut].get<gpu, 4, real_t>(s);
Tensor<gpu, 4, DType> data = in_data[pool_enum::kData].get<gpu, 4, DType>(s);
Tensor<gpu, 4, DType> out = out_data[pool_enum::kOut].get<gpu, 4, DType>(s);
CHECK_EQ(cudnnCreatePoolingDescriptor(&pooling_desc_), CUDNN_STATUS_SUCCESS);
CHECK_EQ(cudnnCreateTensorDescriptor(&in_desc_), CUDNN_STATUS_SUCCESS);
CHECK_EQ(cudnnCreateTensorDescriptor(&out_desc_), CUDNN_STATUS_SUCCESS);
Expand Down
99 changes: 57 additions & 42 deletions src/operator/leaky_relu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,13 @@ struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> {
};

// Element-wise functor used as F<prelu_grad>(data) when accumulating the PReLU
// slope gradient in Backward: Map yields zero for strictly positive inputs and
// returns every other input unchanged.
// NOTE(review): this span looks like a rendered diff hunk, so the earlier
// real_t overload and its templated DType replacement appear back to back —
// confirm against the actual header before relying on the brace structure.
struct prelu_grad {
MSHADOW_XINLINE static real_t Map(real_t a) {
return a > 0.0f ? 0.0f : a;
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a) {
// Every operand is wrapped in DType so the comparison and result stay in DType.
return DType(a > DType(0.0f) ? DType(0.0f) : a);
}
};

template<typename xpu>
template<typename xpu, typename DType>
class LeakyReLUOp : public Operator {
public:
explicit LeakyReLUOp(LeakyReLUParam param) {
Expand All @@ -73,50 +74,56 @@ class LeakyReLUOp : public Operator {
size_t expected = param_.act_type == leakyrelu::kPReLU ? 2 : 1;
CHECK_EQ(in_data.size(), expected);
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 4> data;
Tensor<xpu, 4> out;
Tensor<xpu, 4> mask;
Tensor<xpu, 1> weight;
Tensor<xpu, 4, DType> data;
Tensor<xpu, 4, DType> out;
Tensor<xpu, 4, DType> mask;
Tensor<xpu, 1, DType> weight;
if (in_data[leakyrelu::kData].ndim() == 2) {
Shape<4> dshape = Shape4(in_data[leakyrelu::kData].shape_[0],
in_data[leakyrelu::kData].shape_[1], 1, 1);
data = in_data[leakyrelu::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
out = out_data[leakyrelu::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
data = in_data[leakyrelu::kData].get_with_shape<xpu, 4, DType>(dshape, s);
out = out_data[leakyrelu::kOut].get_with_shape<xpu, 4, DType>(dshape, s);
if (param_.act_type == leakyrelu::kRReLU) {
mask = out_data[leakyrelu::kMask].get_with_shape<xpu, 4, real_t>(dshape, s);
mask = out_data[leakyrelu::kMask].get_with_shape<xpu, 4, DType>(dshape, s);
}
} else {
data = in_data[leakyrelu::kData].get<xpu, 4, real_t>(s);
out = out_data[leakyrelu::kOut].get<xpu, 4, real_t>(s);
data = in_data[leakyrelu::kData].get<xpu, 4, DType>(s);
out = out_data[leakyrelu::kOut].get<xpu, 4, DType>(s);
if (param_.act_type == leakyrelu::kRReLU) {
mask = out_data[leakyrelu::kMask].get<xpu, 4, real_t>(s);
mask = out_data[leakyrelu::kMask].get<xpu, 4, DType>(s);
}
}
switch (param_.act_type) {
case leakyrelu::kLeakyReLU: {
Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, param_.slope));
ScalarExp<DType> slope = ScalarExp<DType>(param_.slope);
Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, slope));
break;
}
case leakyrelu::kPReLU: {
weight = in_data[leakyrelu::kGamma].get<xpu, 1, real_t>(s);
weight = in_data[leakyrelu::kGamma].get<xpu, 1, DType>(s);
Assign(out, req[leakyrelu::kOut],
F<mshadow_op::xelu>(data, broadcast<1>(weight, out.shape_)));
break;
}
case leakyrelu::kRReLU: {
if (ctx.is_train) {
// TODO(vchuravy): Random doesn't work with Float16, this will lead to a reduced
// entropy for Float64.
Random<xpu>* prnd = ctx.requested[leakyrelu::kRandom].get_random<xpu, real_t>(s);
mask = prnd->uniform(mask.shape_);
mask = mask * (param_.upper_bound - param_.lower_bound) + param_.lower_bound;
mask = tcast<DType>(prnd->uniform(mask.shape_));
mask = mask * ScalarExp<DType>(param_.upper_bound - param_.lower_bound)
+ ScalarExp<DType>(param_.lower_bound);
Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, mask));
} else {
const float slope = (param_.lower_bound + param_.upper_bound) / 2.0f;
ScalarExp<DType> slope =
ScalarExp<DType>((param_.lower_bound + param_.upper_bound) / 2.0f);
Assign(out, req[leakyrelu::kOut], F<mshadow_op::xelu>(data, slope));
}
break;
}
case leakyrelu::kELU: {
Assign(out, req[leakyrelu::kOut], F<mshadow_op::elu>(data, param_.slope));
ScalarExp<DType> slope = ScalarExp<DType>(param_.slope);
Assign(out, req[leakyrelu::kOut], F<mshadow_op::elu>(data, slope));
break;
}
default:
Expand All @@ -138,44 +145,45 @@ class LeakyReLUOp : public Operator {
CHECK_EQ(req.size(), expected);
CHECK_EQ(in_data.size(), expected);
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 4> output;
Tensor<xpu, 4> data;
Tensor<xpu, 4> gdata;
Tensor<xpu, 4> grad;
Tensor<xpu, 4> mask;
Tensor<xpu, 1> weight;
Tensor<xpu, 1> grad_weight;
Tensor<xpu, 4, DType> output;
Tensor<xpu, 4, DType> data;
Tensor<xpu, 4, DType> gdata;
Tensor<xpu, 4, DType> grad;
Tensor<xpu, 4, DType> mask;
Tensor<xpu, 1, DType> weight;
Tensor<xpu, 1, DType> grad_weight;
if (out_grad[leakyrelu::kOut].ndim() == 2) {
Shape<4> dshape = Shape4(out_grad[leakyrelu::kOut].shape_[0],
out_grad[leakyrelu::kOut].shape_[1], 1, 1);
grad = out_grad[leakyrelu::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
gdata = in_grad[leakyrelu::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
output = out_data[leakyrelu::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
grad = out_grad[leakyrelu::kOut].get_with_shape<xpu, 4, DType>(dshape, s);
gdata = in_grad[leakyrelu::kData].get_with_shape<xpu, 4, DType>(dshape, s);
output = out_data[leakyrelu::kOut].get_with_shape<xpu, 4, DType>(dshape, s);
if (param_.act_type == leakyrelu::kRReLU) {
mask = out_data[leakyrelu::kMask].get_with_shape<xpu, 4, real_t>(dshape, s);
mask = out_data[leakyrelu::kMask].get_with_shape<xpu, 4, DType>(dshape, s);
}
if (param_.act_type == leakyrelu::kPReLU) {
data = in_data[leakyrelu::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
data = in_data[leakyrelu::kData].get_with_shape<xpu, 4, DType>(dshape, s);
}
} else {
grad = out_grad[leakyrelu::kOut].get<xpu, 4, real_t>(s);
gdata = in_grad[leakyrelu::kData].get<xpu, 4, real_t>(s);
output = out_data[leakyrelu::kOut].get<xpu, 4, real_t>(s);
grad = out_grad[leakyrelu::kOut].get<xpu, 4, DType>(s);
gdata = in_grad[leakyrelu::kData].get<xpu, 4, DType>(s);
output = out_data[leakyrelu::kOut].get<xpu, 4, DType>(s);
if (param_.act_type == leakyrelu::kRReLU) {
mask = out_data[leakyrelu::kMask].get<xpu, 4, real_t>(s);
mask = out_data[leakyrelu::kMask].get<xpu, 4, DType>(s);
}
if (param_.act_type == leakyrelu::kPReLU) {
data = in_data[leakyrelu::kData].get<xpu, 4, real_t>(s);
data = in_data[leakyrelu::kData].get<xpu, 4, DType>(s);
}
}
switch (param_.act_type) {
case leakyrelu::kLeakyReLU: {
Assign(gdata, req[leakyrelu::kData], F<mshadow_op::xelu_grad>(output, param_.slope) * grad);
ScalarExp<DType> slope = ScalarExp<DType>(param_.slope);
Assign(gdata, req[leakyrelu::kData], F<mshadow_op::xelu_grad>(output, slope) * grad);
break;
}
case leakyrelu::kPReLU: {
weight = in_data[leakyrelu::kGamma].get<xpu, 1, real_t>(s);
grad_weight = in_grad[leakyrelu::kGamma].get<xpu, 1, real_t>(s);
weight = in_data[leakyrelu::kGamma].get<xpu, 1, DType>(s);
grad_weight = in_grad[leakyrelu::kGamma].get<xpu, 1, DType>(s);
grad_weight = sumall_except_dim<1>(F<prelu_grad>(data) * grad);
gdata = F<mshadow_op::xelu_grad>(output, broadcast<1>(weight, data.shape_)) * grad;
break;
Expand All @@ -185,7 +193,8 @@ class LeakyReLUOp : public Operator {
break;
}
case leakyrelu::kELU: {
Assign(gdata, req[leakyrelu::kData], F<mshadow_op::elu_grad>(output, param_.slope) * grad);
ScalarExp<DType> slope = ScalarExp<DType>(param_.slope);
Assign(gdata, req[leakyrelu::kData], F<mshadow_op::elu_grad>(output, slope) * grad);
break;
}
default:
Expand All @@ -198,7 +207,7 @@ class LeakyReLUOp : public Operator {
}; // class LeakyReLUOp

template<typename xpu>
Operator* CreateOp(LeakyReLUParam type);
Operator* CreateOp(LeakyReLUParam type, int dtype);

#if DMLC_USE_CXX11
class LeakyReLUProp : public OperatorProperty {
Expand Down Expand Up @@ -315,7 +324,13 @@ class LeakyReLUProp : public OperatorProperty {
}
}

Operator* CreateOperator(Context ctx) const override;
// The type-agnostic creation path is intentionally unsupported here: it emits
// a fatal "Not Implemented" log and hands back a null operator.
Operator* CreateOperator(Context ctx) const override {
  LOG(FATAL) << "Not Implemented";
  return nullptr;
}

Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
std::vector<int> *in_type) const override;

private:
LeakyReLUParam param_;
Expand Down
17 changes: 13 additions & 4 deletions src/operator/leaky_relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,21 @@
namespace mxnet {
namespace op {
// CPU specialization of the LeakyReLU operator factory.
// The dtype-aware form starts with a NULL result and assigns it inside
// MSHADOW_REAL_TYPE_SWITCH, which binds DType from the runtime `dtype` flag and
// constructs LeakyReLUOp<cpu, DType>(param).
// NOTE(review): what the macro does for a dtype it does not cover is defined by
// mshadow, not visible here — confirm whether `op` can be returned as NULL.
template<>
Operator *CreateOp<cpu>(LeakyReLUParam param) {
return new LeakyReLUOp<cpu>(param);
Operator *CreateOp<cpu>(LeakyReLUParam param, int dtype) {
Operator *op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new LeakyReLUOp<cpu, DType>(param);
});
return op;
}

// Operator creation entry points for LeakyReLUProp.
// CreateOperatorEx runs InferType and InferShape on the caller-supplied
// in_type / in_shape (both wrapped in CHECK, so a false result is fatal), then
// dispatches CreateOp with param_ and the dtype of input 0.
// NOTE(review): DO_BIND_DISPATCH presumably selects the cpu/gpu specialization
// from `ctx` and returns its result — verify against the macro definition.
Operator *LeakyReLUProp::CreateOperator(Context ctx) const {
DO_BIND_DISPATCH(CreateOp, param_);
Operator *LeakyReLUProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
std::vector<int> *in_type) const {
// Scratch outputs: only needed to satisfy the Infer* signatures.
std::vector<TShape> out_shape, aux_shape;
std::vector<int> out_type, aux_type;
CHECK(InferType(in_type, &out_type, &aux_type));
CHECK(InferShape(in_shape, &out_shape, &aux_shape));
// The operator is instantiated for the dtype of the first input.
DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
}

DMLC_REGISTER_PARAMETER(LeakyReLUParam);
Expand Down
8 changes: 6 additions & 2 deletions src/operator/leaky_relu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
namespace mxnet {
namespace op {
// GPU specialization of the LeakyReLU operator factory.
// The dtype-aware form starts with a NULL result and assigns it inside
// MSHADOW_REAL_TYPE_SWITCH, which binds DType from the runtime `dtype` flag and
// constructs LeakyReLUOp<gpu, DType>(param).
// NOTE(review): what the macro does for a dtype it does not cover is defined by
// mshadow, not visible here — confirm whether `op` can be returned as NULL.
template<>
Operator *CreateOp<gpu>(LeakyReLUParam param) {
return new LeakyReLUOp<gpu>(param);
Operator *CreateOp<gpu>(LeakyReLUParam param, int dtype) {
Operator *op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new LeakyReLUOp<gpu, DType>(param);
});
return op;
}

} // namespace op
Expand Down
52 changes: 37 additions & 15 deletions src/operator/pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
}
};

template<typename xpu, typename Reducer>
template<typename xpu, typename Reducer, typename DType>
class PoolingOp : public Operator {
public:
explicit PoolingOp(PoolingParam p) {
Expand All @@ -77,8 +77,8 @@ class PoolingOp : public Operator {
CHECK_EQ(in_data.size(), 1);
CHECK_EQ(out_data.size(), 1);
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 4> data = in_data[pool_enum::kData].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> out = out_data[pool_enum::kOut].get<xpu, 4, real_t>(s);
Tensor<xpu, 4, DType> data = in_data[pool_enum::kData].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> out = out_data[pool_enum::kOut].get<xpu, 4, DType>(s);
mshadow::Shape<2> out_shape = Shape2(out.shape_[2], out.shape_[3]);
if (param_.pool_type == pool_enum::kMaxPooling || param_.pool_type == pool_enum::kSumPooling) {
Assign(out,
Expand All @@ -90,12 +90,12 @@ class PoolingOp : public Operator {
param_.global_pool ? 1 : param_.stride[0],
param_.global_pool ? 1 : param_.stride[1]));
} else if (param_.pool_type == pool_enum::kAvgPooling) {
ScalarExp<DType> x = ScalarExp<DType>(1.0f / (param_.global_pool ?
data.shape_[2] * data.shape_[3] :
param_.kernel[0] * param_.kernel[1]));
Assign(out,
req[pool_enum::kOut],
(1.0f / (param_.global_pool ?
data.shape_[2] * data.shape_[3] :
param_.kernel[0] * param_.kernel[1])) * \
pool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
x * pool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
out_shape,
param_.global_pool ? data.shape_[2] : param_.kernel[0],
param_.global_pool ? data.shape_[3] : param_.kernel[1],
Expand All @@ -120,10 +120,10 @@ class PoolingOp : public Operator {
CHECK_EQ(in_grad.size(), 1);
// TODO(bing): remove pad (0,0)
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 4> grad = out_grad[pool_enum::kOut].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> data = in_data[pool_enum::kData].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> output_data = out_data[pool_enum::kOut].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> input_grad = in_grad[pool_enum::kData].get<xpu, 4, real_t>(s);
Tensor<xpu, 4, DType> grad = out_grad[pool_enum::kOut].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> data = in_data[pool_enum::kData].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> output_data = out_data[pool_enum::kOut].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> input_grad = in_grad[pool_enum::kData].get<xpu, 4, DType>(s);

mshadow::Shape<2> in_shape = Shape2(data.shape_[2], data.shape_[3]);

Expand All @@ -140,9 +140,9 @@ class PoolingOp : public Operator {
param_.pad[0],
param_.pad[1]));
} else if (param_.pool_type == pool_enum::kAvgPooling) {
ScalarExp<DType> x = ScalarExp<DType>(1.0f / param_.kernel[0] / param_.kernel[1]);
Assign(input_grad, req[pool_enum::kData],
(1.0f / param_.kernel[0] / param_.kernel[1]) *\
crop(unpool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
x * crop(unpool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
pad(output_data, 0, 0),
pad(grad, 0, 0),
param_.global_pool ? in_shape[0] : param_.kernel[0],
Expand All @@ -160,7 +160,7 @@ class PoolingOp : public Operator {
}; // class PoolingOp

template<typename xpu>
Operator* CreateOp(PoolingParam param);
Operator* CreateOp(PoolingParam param, int dtype);


#if DMLC_USE_CXX11
Expand Down Expand Up @@ -198,6 +198,22 @@ class PoolingProp : public OperatorProperty {
return true;
}

// Propagates the dtype of the single pooling input to the single output.
//   in_type  - must contain exactly one entry (enforced with CHECK_EQ); a value
//              of -1 is treated as "not specified".
//   out_type - reset to hold exactly one entry equal to the input dtype.
//   aux_type - left untouched.
// Returns true on success. An unspecified input dtype is reported through
// LOG(FATAL); the `return false` after it is the fallback result.
bool InferType(std::vector<int> *in_type,
               std::vector<int> *out_type,
               std::vector<int> *aux_type) const override {
  CHECK_EQ(in_type->size(), 1);

  const int input_dtype = in_type->at(0);
  const bool dtype_known = (input_dtype != -1);
  if (!dtype_known) {
    LOG(FATAL) << "input type to pooling is not specified.";
    return false;
  }

  // Output dtype mirrors the input dtype.
  out_type->clear();
  out_type->push_back(input_dtype);
  return true;
}

OperatorProperty* Copy() const override {
PoolingProp *prop_sym = new PoolingProp();
prop_sym->param_ = this->param_;
Expand Down Expand Up @@ -227,7 +243,13 @@ class PoolingProp : public OperatorProperty {
#endif
}

Operator* CreateOperator(Context ctx) const override;
// The type-agnostic creation path is intentionally unsupported here: it emits
// a fatal "Not Implemented" log and hands back a null operator.
Operator* CreateOperator(Context ctx) const override {
  LOG(FATAL) << "Not Implemented";
  return nullptr;
}

Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
std::vector<int> *in_type) const override;

private:
PoolingParam param_;
Expand Down
Loading