@@ -787,6 +787,284 @@ void BipartiteMatchingBackward(const nnvm::NodeAttrs& attrs,
787
787
});
788
788
}
789
789
790
+
791
+ inline bool BoxEncodeShape (const nnvm::NodeAttrs& attrs,
792
+ mxnet::ShapeVector *in_attrs,
793
+ mxnet::ShapeVector *out_attrs) {
794
+ CHECK_EQ (in_attrs->size (), 6U );
795
+ CHECK_EQ (out_attrs->size (), 2U );
796
+ mxnet::TShape& sshape = (*in_attrs)[0 ];
797
+ mxnet::TShape& mshape = (*in_attrs)[1 ];
798
+ mxnet::TShape& ashape = (*in_attrs)[2 ];
799
+ mxnet::TShape& rshape = (*in_attrs)[3 ];
800
+
801
+ CHECK_EQ (sshape.ndim (), 2 )
802
+ << " samples shape must have dim == 2, "
803
+ << sshape.ndim () << " provided" ;
804
+
805
+ CHECK_GE (mshape.ndim (), 2 )
806
+ << " matches shape must have dim == 2, "
807
+ << mshape.ndim () << " provided" ;
808
+
809
+ CHECK_GE (ashape.ndim (), 3 )
810
+ << " matches shape must have dim == 3, "
811
+ << ashape.ndim () << " provided" ;
812
+ int ldim = ashape[ashape.ndim () - 1 ];
813
+ CHECK_EQ (ldim, 4 )
814
+ << " last dimension of anchors must be 4, "
815
+ << ldim << " provided" ;
816
+
817
+ CHECK_GE (rshape.ndim (), 3 )
818
+ << " refs shape must have dim == 3, "
819
+ << ashape.ndim () << " provided" ;
820
+ ldim = rshape[rshape.ndim () - 1 ];
821
+ CHECK_EQ (ldim, 4 )
822
+ << " last dimension of anchors must be 4, "
823
+ << ldim << " provided" ;
824
+
825
+ // asign input shape
826
+ SHAPE_ASSIGN_CHECK (*in_attrs, 4 , mshadow::Shape1 (4 ));
827
+ SHAPE_ASSIGN_CHECK (*in_attrs, 5 , mshadow::Shape1 (4 ));
828
+
829
+ // assign output shape
830
+ mxnet::TShape oshape = ashape;
831
+ SHAPE_ASSIGN_CHECK (*out_attrs, 0 , oshape);
832
+ SHAPE_ASSIGN_CHECK (*out_attrs, 1 , oshape);
833
+ return shape_is_known (oshape);
834
+ }
835
+
836
+ struct box_encode {
837
+ template <typename DType>
838
+ MSHADOW_XINLINE static void Map (index_t i, DType *out_targets, DType *out_masks,
839
+ const DType *samples, const DType *matches,
840
+ const DType *anchors, const DType *refs,
841
+ const DType *means, const DType *stds,
842
+ const int m, const int n) {
843
+ index_t j = i / n;
844
+ index_t match = matches[i];
845
+ // xmin: 0, ymin:1, xmax: 2, ymax: 3
846
+ // x:0, y:1, w:2, h:3
847
+ index_t ref_index = (j * m + match) * 4 ;
848
+ DType ref_xmin = refs[ref_index + 0 ];
849
+ DType ref_ymin = refs[ref_index + 1 ];
850
+ DType ref_width = refs[ref_index + 2 ] - ref_xmin;
851
+ DType ref_height = refs[ref_index + 3 ] - ref_ymin;
852
+ DType ref_x = ref_xmin + ref_width * 0.5 ;
853
+ DType ref_y = ref_ymin + ref_height * 0.5 ;
854
+ index_t a_index = i * 4 ;
855
+ DType a_xmin = anchors[a_index + 0 ];
856
+ DType a_ymin = anchors[a_index + 1 ];
857
+ DType a_width = anchors[a_index + 2 ] - a_xmin;
858
+ DType a_height = anchors[a_index + 3 ] - a_ymin;
859
+ DType a_x = a_xmin + a_width * 0.5 ;
860
+ DType a_y = a_ymin + a_height * 0.5 ;
861
+ DType valid = samples[i] > 0.5 ? 1.0 : 0.0 ;
862
+ out_masks[a_index + 0 ] = valid;
863
+ out_masks[a_index + 1 ] = valid;
864
+ out_masks[a_index + 2 ] = valid;
865
+ out_masks[a_index + 3 ] = valid;
866
+ out_targets[a_index + 0 ] = valid > static_cast <DType>(0.5 ) ?
867
+ ((ref_x - a_x) / a_width - static_cast <DType>(means[0 ])) /
868
+ static_cast <DType>(stds[0 ]) : static_cast <DType>(0.0 );
869
+ out_targets[a_index + 1 ] = valid > static_cast <DType>(0.5 ) ?
870
+ ((ref_y - a_y) / a_height - static_cast <DType>(means[1 ])) /
871
+ static_cast <DType>(stds[1 ]) : static_cast <DType>(0.0 );
872
+ out_targets[a_index + 2 ] = valid > static_cast <DType>(0.5 ) ?
873
+ (log (ref_width / a_width) - static_cast <DType>(means[2 ])) /
874
+ static_cast <DType>(stds[2 ]) : static_cast <DType>(0.0 );
875
+ out_targets[a_index + 3 ] = valid > static_cast <DType>(0.5 ) ?
876
+ (log (ref_height / a_height) - static_cast <DType>(means[3 ])) /
877
+ static_cast <DType>(stds[3 ]) : static_cast <DType>(0.0 );
878
+ }
879
+ };
880
+
881
+ template <typename xpu>
882
+ void BoxEncodeForward (const nnvm::NodeAttrs& attrs,
883
+ const OpContext& ctx,
884
+ const std::vector<TBlob>& inputs,
885
+ const std::vector<OpReqType>& req,
886
+ const std::vector<TBlob>& outputs) {
887
+ using namespace mshadow ;
888
+ using namespace mshadow ::expr;
889
+ using namespace mxnet_op ;
890
+ CHECK_EQ (inputs.size (), 6U );
891
+ CHECK_EQ (outputs.size (), 2U );
892
+ Stream<xpu> *s = ctx.get_stream <xpu>();
893
+ // samples, matches, anchors, refs, means, stds
894
+ mxnet::TShape anchor_shape = inputs[2 ].shape_ ;
895
+ int loop_size = anchor_shape.ProdShape (0 , 2 );
896
+ int b = anchor_shape[0 ];
897
+ int n = anchor_shape[1 ];
898
+ int m = inputs[3 ].shape_ [1 ];
899
+ MSHADOW_REAL_TYPE_SWITCH (outputs[0 ].type_flag_ , DType, {
900
+ Tensor<xpu, 2 , DType> samples = inputs[0 ]
901
+ .get_with_shape <xpu, 2 , DType>(Shape2 (b, n), s);
902
+ Tensor<xpu, 2 , DType> matches = inputs[1 ]
903
+ .get_with_shape <xpu, 2 , DType>(Shape2 (b, n), s);
904
+ Tensor<xpu, 3 , DType> anchors = inputs[2 ]
905
+ .get_with_shape <xpu, 3 , DType>(Shape3 (b, n, 4 ), s);
906
+ Tensor<xpu, 3 , DType> refs = inputs[3 ]
907
+ .get_with_shape <xpu, 3 , DType>(Shape3 (b, m, 4 ), s);
908
+ Tensor<xpu, 1 , DType> means = inputs[4 ]
909
+ .get_with_shape <xpu, 1 , DType>(Shape1 (4 ), s);
910
+ Tensor<xpu, 1 , DType> stds = inputs[5 ]
911
+ .get_with_shape <xpu, 1 , DType>(Shape1 (4 ), s);
912
+ Tensor<xpu, 3 , DType> out_targets = outputs[0 ]
913
+ .get_with_shape <xpu, 3 , DType>(Shape3 (b, n, 4 ), s);
914
+ Tensor<xpu, 3 , DType> out_masks = outputs[1 ]
915
+ .get_with_shape <xpu, 3 , DType>(Shape3 (b, n, 4 ), s);
916
+
917
+ Kernel<box_encode, xpu>::Launch (s, loop_size, out_targets.dptr_ ,
918
+ out_masks.dptr_ , samples.dptr_ , matches.dptr_ , anchors.dptr_ ,
919
+ refs.dptr_ , means.dptr_ , stds.dptr_ , m, n);
920
+ });
921
+ }
922
+
923
+ struct BoxDecodeParam : public dmlc ::Parameter<BoxDecodeParam> {
924
+ float std0;
925
+ float std1;
926
+ float std2;
927
+ float std3;
928
+ float clip;
929
+ int format;
930
+ DMLC_DECLARE_PARAMETER (BoxDecodeParam) {
931
+ DMLC_DECLARE_FIELD (std0).set_default (1.0 )
932
+ .describe (" value to be divided from the 1st encoded values" );
933
+ DMLC_DECLARE_FIELD (std1).set_default (1.0 )
934
+ .describe (" value to be divided from the 2nd encoded values" );
935
+ DMLC_DECLARE_FIELD (std2).set_default (1.0 )
936
+ .describe (" value to be divided from the 3rd encoded values" );
937
+ DMLC_DECLARE_FIELD (std3).set_default (1.0 )
938
+ .describe (" value to be divided from the 4th encoded values" );
939
+ DMLC_DECLARE_FIELD (clip).set_default (-1.0 )
940
+ .describe (" If larger than 0, bounding box target will be clipped to this value." );
941
+ DMLC_DECLARE_FIELD (format).set_default (box_common_enum::kCenter )
942
+ .add_enum (" corner" , box_common_enum::kCorner )
943
+ .add_enum (" center" , box_common_enum::kCenter )
944
+ .describe (" The box encoding type. \n "
945
+ " \" corner\" means boxes are encoded as [xmin, ymin, xmax, ymax],"
946
+ " \" center\" means boxes are encodes as [x, y, width, height]." );
947
+ }
948
+ }; // BoxDecodeParam
949
+
950
+ inline bool BoxDecodeShape (const nnvm::NodeAttrs& attrs,
951
+ mxnet::ShapeVector *in_attrs,
952
+ mxnet::ShapeVector *out_attrs) {
953
+ CHECK_EQ (in_attrs->size (), 2U );
954
+ CHECK_EQ (out_attrs->size (), 1U );
955
+ mxnet::TShape& dshape = (*in_attrs)[0 ];
956
+ mxnet::TShape& ashape = (*in_attrs)[1 ];
957
+
958
+ CHECK_EQ (dshape.ndim (), 3 )
959
+ << " data shape must have dim == 3, "
960
+ << dshape.ndim () << " provided" ;
961
+ int ldim = dshape[dshape.ndim () - 1 ];
962
+ CHECK_EQ (ldim, 4 )
963
+ << " last dimension of data must be 4, "
964
+ << ldim << " provided" ;
965
+
966
+ CHECK_GE (ashape.ndim (), 3 )
967
+ << " anchors shape must have dim == 3, "
968
+ << ashape.ndim () << " provided" ;
969
+ ldim = ashape[ashape.ndim () - 1 ];
970
+ CHECK_EQ (ldim, 4 )
971
+ << " last dimension of anchors must be 4, "
972
+ << ldim << " provided" ;
973
+
974
+ // assign output shape
975
+ mxnet::TShape oshape = dshape;
976
+ SHAPE_ASSIGN_CHECK (*out_attrs, 0 , oshape);
977
+ return shape_is_known (oshape);
978
+ }
979
+
980
+ template <int anchor_encode, bool has_clip>
981
+ struct box_decode {
982
+ template <typename DType>
983
+ MSHADOW_XINLINE static void Map (index_t i, DType *out, const DType *x,
984
+ const DType *anchors, const DType std0,
985
+ const DType std1, const DType std2,
986
+ const DType std3, const DType clip,
987
+ const int n) {
988
+ index_t index = i * 4 ;
989
+ index_t a_index = (i % n) * 4 ;
990
+ DType a_x = anchors[a_index + 0 ];
991
+ DType a_y = anchors[a_index + 1 ];
992
+ DType a_width = anchors[a_index + 2 ];
993
+ DType a_height = anchors[a_index + 3 ];
994
+ if (box_common_enum::kCorner == anchor_encode) {
995
+ // a_x = xmin, a_y = ymin, a_width = xmax, a_height = ymax
996
+ a_width = a_width - a_x;
997
+ a_height = a_height - a_y;
998
+ a_x = a_x + a_width * 0.5 ;
999
+ a_y = a_y + a_height * 0.5 ;
1000
+ }
1001
+ DType ox = x[index + 0 ] * std0 * a_width + a_x;
1002
+ DType oy = x[index + 1 ] * std1 * a_height + a_y;
1003
+ DType dw = x[index + 2 ] * std2;
1004
+ DType dh = x[index + 3 ] * std3;
1005
+ if (has_clip) {
1006
+ dw = dw < clip ? dw : clip;
1007
+ dh = dh < clip ? dh : clip;
1008
+ }
1009
+ dw = exp (dw);
1010
+ dh = exp (dh);
1011
+ DType ow = dw * a_width * 0.5 ;
1012
+ DType oh = dh * a_height * 0.5 ;
1013
+ out[index + 0 ] = ox - ow;
1014
+ out[index + 1 ] = oy - oh;
1015
+ out[index + 2 ] = ox + ow;
1016
+ out[index + 3 ] = oy + oh;
1017
+ }
1018
+ };
1019
+
1020
+ template <typename xpu>
1021
+ void BoxDecodeForward (const nnvm::NodeAttrs& attrs,
1022
+ const OpContext& ctx,
1023
+ const std::vector<TBlob>& inputs,
1024
+ const std::vector<OpReqType>& req,
1025
+ const std::vector<TBlob>& outputs) {
1026
+ using namespace mshadow ;
1027
+ using namespace mshadow ::expr;
1028
+ using namespace mxnet_op ;
1029
+ CHECK_EQ (inputs.size (), 2U );
1030
+ CHECK_EQ (outputs.size (), 1U );
1031
+ Stream<xpu> *s = ctx.get_stream <xpu>();
1032
+ mxnet::TShape x_shape = inputs[0 ].shape_ ;
1033
+ int b = x_shape[0 ];
1034
+ int n = x_shape[1 ];
1035
+ int loop_size = b * n;
1036
+ const BoxDecodeParam& param = nnvm::get<BoxDecodeParam>(attrs.parsed );
1037
+ MSHADOW_REAL_TYPE_SWITCH (outputs[0 ].type_flag_ , DType, {
1038
+ Tensor<xpu, 3 , DType> data = inputs[0 ]
1039
+ .get_with_shape <xpu, 3 , DType>(Shape3 (b, n, 4 ), s);
1040
+ Tensor<xpu, 3 , DType> anchors = inputs[1 ]
1041
+ .get_with_shape <xpu, 3 , DType>(Shape3 (1 , n, 4 ), s);
1042
+ Tensor<xpu, 3 , DType> out = outputs[0 ]
1043
+ .get_with_shape <xpu, 3 , DType>(Shape3 (b, n, 4 ), s);
1044
+ if (box_common_enum::kCorner == param.format && param.clip > 0.0 ) {
1045
+ Kernel<box_decode<box_common_enum::kCorner , true >, xpu>::Launch (s, loop_size,
1046
+ out.dptr_ , data.dptr_ , anchors.dptr_ , static_cast <DType>(param.std0 ),
1047
+ static_cast <DType>(param.std1 ), static_cast <DType>(param.std2 ),
1048
+ static_cast <DType>(param.std3 ), static_cast <DType>(param.clip ), n);
1049
+ } else if (box_common_enum::kCenter == param.format && param.clip > 0.0 ) {
1050
+ Kernel<box_decode<box_common_enum::kCenter , true >, xpu>::Launch (s, loop_size,
1051
+ out.dptr_ , data.dptr_ , anchors.dptr_ , static_cast <DType>(param.std0 ),
1052
+ static_cast <DType>(param.std1 ), static_cast <DType>(param.std2 ),
1053
+ static_cast <DType>(param.std3 ), static_cast <DType>(param.clip ), n);
1054
+ } else if (box_common_enum::kCorner == param.format && param.clip <= 0.0 ) {
1055
+ Kernel<box_decode<box_common_enum::kCorner , false >, xpu>::Launch (s, loop_size,
1056
+ out.dptr_ , data.dptr_ , anchors.dptr_ , static_cast <DType>(param.std0 ),
1057
+ static_cast <DType>(param.std1 ), static_cast <DType>(param.std2 ),
1058
+ static_cast <DType>(param.std3 ), static_cast <DType>(param.clip ), n);
1059
+ } else {
1060
+ Kernel<box_decode<box_common_enum::kCenter , false >, xpu>::Launch (s, loop_size,
1061
+ out.dptr_ , data.dptr_ , anchors.dptr_ , static_cast <DType>(param.std0 ),
1062
+ static_cast <DType>(param.std1 ), static_cast <DType>(param.std2 ),
1063
+ static_cast <DType>(param.std3 ), static_cast <DType>(param.clip ), n);
1064
+ }
1065
+ });
1066
+ }
1067
+
790
1068
} // namespace op
791
1069
} // namespace mxnet
792
1070
0 commit comments