This repository was archived by the owner on Nov 17, 2023. It is now read-only.

OP ROIPooling CPU fix and DType support #3011

Merged: 2 commits, Aug 15, 2016
42 changes: 31 additions & 11 deletions src/operator/roi_pooling-inl.h
@@ -41,7 +41,7 @@ struct ROIPoolingParam : public dmlc::Parameter<ROIPoolingParam> {
}
};

template<typename xpu>
template<typename xpu, typename DType>
class ROIPoolingOp : public Operator {
public:
explicit ROIPoolingOp(ROIPoolingParam p) {
@@ -61,10 +61,10 @@ class ROIPoolingOp : public Operator {
CHECK_EQ(out_data[roipool::kMaxIdx].shape_[0], in_data[roipool::kBox].shape_[0]);
Stream<xpu> *s = ctx.get_stream<xpu>();

Tensor<xpu, 4> data = in_data[roipool::kData].get<xpu, 4, real_t>(s);
Tensor<xpu, 2> bbox = in_data[roipool::kBox].get<xpu, 2, real_t>(s);
Tensor<xpu, 4> out = out_data[roipool::kOut].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, real_t>(s);
Tensor<xpu, 4, DType> data = in_data[roipool::kData].get<xpu, 4, DType>(s);
Tensor<xpu, 2, DType> bbox = in_data[roipool::kBox].get<xpu, 2, DType>(s);
Tensor<xpu, 4, DType> out = out_data[roipool::kOut].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, DType>(s);
CHECK_EQ(data.CheckContiguous(), true);
CHECK_EQ(bbox.CheckContiguous(), true);
CHECK_EQ(out.CheckContiguous(), true);
@@ -90,10 +90,10 @@ class ROIPoolingOp : public Operator {
CHECK_EQ(req[roipool::kOut], kWriteTo);
Stream<xpu> *s = ctx.get_stream<xpu>();

Tensor<xpu, 4> grad_out = out_grad[roipool::kOut].get<xpu, 4, real_t>(s);
Tensor<xpu, 2> bbox = in_data[roipool::kBox].get<xpu, 2, real_t>(s);
Tensor<xpu, 4> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> grad_in = in_grad[roipool::kData].get<xpu, 4, real_t>(s);
Tensor<xpu, 4, DType> grad_out = out_grad[roipool::kOut].get<xpu, 4, DType>(s);
Tensor<xpu, 2, DType> bbox = in_data[roipool::kBox].get<xpu, 2, DType>(s);
Tensor<xpu, 4, DType> max_idx = out_data[roipool::kMaxIdx].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> grad_in = in_grad[roipool::kData].get<xpu, 4, DType>(s);
CHECK_EQ(grad_out.CheckContiguous(), true);
CHECK_EQ(bbox.CheckContiguous(), true);
CHECK_EQ(max_idx.CheckContiguous(), true);
@@ -108,7 +108,7 @@ class ROIPoolingOp : public Operator {

// Declare Factory function, used for dispatch specialization
template<typename xpu>
Operator* CreateOp(ROIPoolingParam param);
Operator* CreateOp(ROIPoolingParam param, int dtype);

#if DMLC_USE_CXX11
class ROIPoolingProp : public OperatorProperty {
@@ -162,6 +162,20 @@ class ROIPoolingProp : public OperatorProperty {
return true;
}

bool InferType(std::vector<int> *in_type,
std::vector<int> *out_type,
std::vector<int> *aux_type) const override {
CHECK_EQ(in_type->size(), 2);
int dtype = (*in_type)[0];
CHECK_EQ(dtype, (*in_type)[1]);
CHECK_NE(dtype, -1) << "Input must have specified type";

out_type->clear();
out_type->push_back(dtype);
out_type->push_back(dtype);
return true;
}

OperatorProperty* Copy() const override {
ROIPoolingProp* roi_pooling_sym = new ROIPoolingProp();
roi_pooling_sym->param_ = this->param_;
@@ -180,7 +194,13 @@ class ROIPoolingProp : public OperatorProperty {
return {out_grad[roipool::kOut], in_data[roipool::kBox], out_data[roipool::kMaxIdx]};
}

Operator* CreateOperator(Context ctx) const override;
Operator* CreateOperator(Context ctx) const override {
LOG(FATAL) << "Not Implemented.";
return NULL;
}

Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
std::vector<int> *in_type) const override;

private:
ROIPoolingParam param_;
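For reference, the practical effect of the new InferType method is that the operator's output dtype now follows its inputs rather than being fixed to real_t. A rough Python sketch of how this could be observed (hypothetical usage; it assumes the standard Symbol.infer_type API):

import mxnet as mx
import numpy as np

# Hypothetical check: output dtypes should follow the input dtypes.
data = mx.symbol.Variable(name='data')
rois = mx.symbol.Variable(name='rois')
sym = mx.symbol.ROIPooling(data=data, rois=rois, pooled_size=(6, 6), spatial_scale=1)
arg_types, out_types, aux_types = sym.infer_type(data=np.float64, rois=np.float64)
print(out_types)  # expected to report float64 once the dtype is inferred from the inputs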
34 changes: 20 additions & 14 deletions src/operator/roi_pooling.cc
@@ -132,22 +132,23 @@ inline void ROIPoolBackward(const Tensor<cpu, 4, Dtype> &in_grad,
for (int h = 0; h < height_; ++h) {
for (int w = 0; w < width_; ++w) {
int offset_bottom_diff = (b * channels_ + c) * height_ * width_;
offset_bottom_diff += h * height_ + w;
offset_bottom_diff += h * width_ + w;

Dtype gradient = 0;
// Accumulate gradient over all ROIs that pooled this element
for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
int roi_batch_ind = bottom_rois[0];
const Dtype* offset_bottom_rois = bottom_rois + roi_n * 5;
int roi_batch_ind = offset_bottom_rois[0];
assert(roi_batch_ind >= 0);
assert(roi_batch_ind < batch_size_);
if (b != roi_batch_ind) {
continue;
}

int roi_start_w = round(bottom_rois[1] * spatial_scale_);
int roi_start_h = round(bottom_rois[2] * spatial_scale_);
int roi_end_w = round(bottom_rois[3] * spatial_scale_);
int roi_end_h = round(bottom_rois[4] * spatial_scale_);
int roi_start_w = round(offset_bottom_rois[1] * spatial_scale_);
int roi_start_h = round(offset_bottom_rois[2] * spatial_scale_);
int roi_end_w = round(offset_bottom_rois[3] * spatial_scale_);
int roi_end_h = round(offset_bottom_rois[4] * spatial_scale_);

bool in_roi = (w >= roi_start_w && w <= roi_end_w &&
h >= roi_start_h && h <= roi_end_h);
@@ -191,9 +192,6 @@ inline void ROIPoolBackward(const Tensor<cpu, 4, Dtype> &in_grad,
}
}
}

// Increment ROI data pointer
bottom_rois += bbox.size(1);
}
bottom_diff[offset_bottom_diff] = gradient;
}
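As a side note on the indexing fix in this hunk: for a row-major (channels, height, width) buffer, the per-pixel offset must scale h by width_, not height_, which is what the change above corrects. A tiny NumPy illustration with hypothetical sizes (not part of the patch):

import numpy as np

C, H, W = 3, 4, 5                       # hypothetical sizes
x = np.arange(C * H * W).reshape(C, H, W)
c, h, w = 1, 2, 3
right = (c * H + h) * W + w             # row-major offset: h scaled by W
wrong = c * H * W + h * H + w           # the pre-fix form scaled h by H instead
assert x.flat[right] == x[c, h, w]      # picks the intended element
assert x.flat[wrong] != x[c, h, w]      # picks the wrong element whenever H != W and h > 0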
@@ -209,13 +207,21 @@ namespace mxnet {
namespace op {

template<>
Operator* CreateOp<cpu>(ROIPoolingParam param) {
return new ROIPoolingOp<cpu>(param);
Operator *CreateOp<cpu>(ROIPoolingParam param, int dtype) {
Operator* op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new ROIPoolingOp<cpu, DType>(param);
});
return op;
}
Contributor:
It looks good to me. I'll add some testing code for the pooling layer later; after that you can follow it to create a testing script that validates the ROI pooling result.
It would be better if you could add a ; after the } to keep it consistent with the other operators. Many thanks. ;)

Contributor Author:
Thanks for the review. I didn't quite follow the ; after } advice. Would you please elaborate on that?

Contributor:
I'm OK.
Maybe I'll do that for you later. LOL.
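Regarding the testing-script suggestion above, one option is to compare the operator's forward output against a plain NumPy reference. The sketch below is illustrative only: the helper name roi_pool_ref is hypothetical, and it assumes Caffe-style ROI pooling semantics (ROI corners scaled by spatial_scale and rounded, empty bins producing 0).

import numpy as np

def roi_pool_ref(data, rois, pooled_size, spatial_scale):
    # NumPy reference for ROIPooling forward (illustrative only).
    # data: (N, C, H, W); rois: (R, 5) rows of [batch_idx, x1, y1, x2, y2].
    ph, pw = pooled_size
    _, c, h, w = data.shape
    out = np.zeros((rois.shape[0], c, ph, pw), dtype=data.dtype)
    for r, roi in enumerate(rois):
        b = int(roi[0])
        x1, y1, x2, y2 = np.round(np.asarray(roi[1:]) * spatial_scale).astype(int)
        roi_h = max(y2 - y1 + 1, 1)
        roi_w = max(x2 - x1 + 1, 1)
        for i in range(ph):
            hs = min(max(y1 + int(np.floor(i * roi_h / float(ph))), 0), h)
            he = min(max(y1 + int(np.ceil((i + 1) * roi_h / float(ph))), 0), h)
            for j in range(pw):
                ws = min(max(x1 + int(np.floor(j * roi_w / float(pw))), 0), w)
                we = min(max(x1 + int(np.ceil((j + 1) * roi_w / float(pw))), 0), w)
                if he > hs and we > ws:
                    out[r, :, i, j] = data[b, :, hs:he, ws:we].max(axis=(1, 2))
    return out

The operator's output could then be checked against this reference with numpy.testing.assert_allclose.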


// DO_BIND_DISPATCH comes from static_operator_common.h
Operator* ROIPoolingProp::CreateOperator(Context ctx) const {
DO_BIND_DISPATCH(CreateOp, param_);
Operator *ROIPoolingProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
std::vector<int> *in_type) const {
std::vector<TShape> out_shape, aux_shape;
std::vector<int> out_type, aux_type;
CHECK(InferType(in_type, &out_type, &aux_type));
CHECK(InferShape(in_shape, &out_shape, &aux_shape));
DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
}

DMLC_REGISTER_PARAMETER(ROIPoolingParam);
18 changes: 7 additions & 11 deletions src/operator/roi_pooling.cu
@@ -10,13 +10,6 @@
#include <algorithm>
#include <vector>

#define ROIPOOLING_CUDA_CHECK(condition) \
/* Code block avoids redefinition of cudaError_t error */ \
do { \
cudaError_t error = condition; \
CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
} while (0)

namespace mshadow {
namespace cuda {

@@ -117,7 +110,6 @@ inline void ROIPoolForward(const Tensor<gpu, 4, Dtype> &out,
ROIPoolForwardKernel<Dtype><<<dimGrid, dimBlock, 0, stream>>>(
count, bottom_data, spatial_scale, channels, height, width,
pooled_height, pooled_width, bottom_rois, top_data, argmax_data);
ROIPOOLING_CUDA_CHECK(cudaPeekAtLastError());
}

template<typename Dtype>
@@ -221,7 +213,6 @@ inline void ROIPoolBackward(const Tensor<gpu, 4, Dtype> &in_grad,
ROIPoolBackwardKernel<Dtype><<<dimGrid, dimBlock, 0, stream>>>(
count, top_diff, argmax_data, num_rois, spatial_scale, channels, height, width,
pooled_height, pooled_width, bottom_diff, bottom_rois);
ROIPOOLING_CUDA_CHECK(cudaPeekAtLastError());
}

} // namespace cuda
@@ -251,8 +242,13 @@ namespace mxnet {
namespace op {

template<>
Operator* CreateOp<gpu>(ROIPoolingParam param) {
return new ROIPoolingOp<gpu>(param);
Operator* CreateOp<gpu>(ROIPoolingParam param, int dtype) {
Operator* op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new ROIPoolingOp<gpu, DType>(param);
});
return op;
}

} // namespace op
} // namespace mxnet
10 changes: 10 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -1437,6 +1437,15 @@ def test_support_vector_machine_l2_svm():
grad_np = grad_np.astype(np.float32)
assert_allclose(grad_np, grad.asnumpy())

def test_roipooling():
data = mx.symbol.Variable(name='data')
rois = mx.symbol.Variable(name='rois')
test = mx.symbol.ROIPooling(data=data, rois=rois, pooled_size=(6, 6), spatial_scale=1)

x1 = np.random.rand(4, 3, 12, 8)
x2 = np.array([[0, 1, 1, 6, 6], [2, 6, 2, 7, 11], [1, 3, 1, 5, 10], [0, 3, 3, 3, 3]])

check_numeric_gradient(test, [x1, x2], numeric_eps=1e-4, check_eps=1e-1)

if __name__ == '__main__':
test_expand_dims()
@@ -1478,3 +1487,4 @@ def test_support_vector_machine_l2_svm():
test_correlation()
test_support_vector_machine_l1_svm()
test_support_vector_machine_l2_svm()
test_roipooling()
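As a usage note on the new DType support, the same ROIPooling symbol can now be bound with different input dtypes. A rough sketch (hypothetical; it assumes simple_bind accepts a type_dict, as MXNet's symbol API did around this time):

import mxnet as mx
import numpy as np

data_np = np.random.rand(2, 3, 12, 8)
rois_np = np.array([[0, 1, 1, 6, 6], [1, 2, 2, 7, 11]])

outs = {}
for dtype in (np.float32, np.float64):
    data = mx.symbol.Variable(name='data')
    rois = mx.symbol.Variable(name='rois')
    sym = mx.symbol.ROIPooling(data=data, rois=rois, pooled_size=(6, 6), spatial_scale=1)
    exe = sym.simple_bind(ctx=mx.cpu(), data=data_np.shape, rois=rois_np.shape,
                          type_dict={'data': dtype, 'rois': dtype})
    exe.arg_dict['data'][:] = data_np.astype(dtype)
    exe.arg_dict['rois'][:] = rois_np.astype(dtype)
    outs[dtype] = exe.forward(is_train=False)[0].asnumpy()

# Forward results should agree across dtypes up to floating-point precision.
np.testing.assert_allclose(outs[np.float32], outs[np.float64].astype(np.float32), rtol=1e-5)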