Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit db79553

Browse files
committed
amp multicast support casting to narrowest type
cpplint cpplint
1 parent 5131a44 commit db79553

File tree

3 files changed

+15
-7
lines changed

3 files changed

+15
-7
lines changed

src/operator/contrib/bounding_box-inl.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -864,13 +864,17 @@ struct box_encode {
864864
out_masks[a_index + 2] = valid;
865865
out_masks[a_index + 3] = valid;
866866
out_targets[a_index + 0] = valid > static_cast<DType>(0.5) ?
867-
((ref_x - a_x) / a_width - static_cast<DType>(means[0])) / static_cast<DType>(stds[0]) : static_cast<DType>(0.0);
867+
((ref_x - a_x) / a_width - static_cast<DType>(means[0])) /
868+
static_cast<DType>(stds[0]) : static_cast<DType>(0.0);
868869
out_targets[a_index + 1] = valid > static_cast<DType>(0.5) ?
869-
((ref_y - a_y) / a_height - static_cast<DType>(means[1])) / static_cast<DType>(stds[1]) : static_cast<DType>(0.0);
870+
((ref_y - a_y) / a_height - static_cast<DType>(means[1])) /
871+
static_cast<DType>(stds[1]) : static_cast<DType>(0.0);
870872
out_targets[a_index + 2] = valid > static_cast<DType>(0.5) ?
871-
(log(ref_width / a_width) - static_cast<DType>(means[2])) / static_cast<DType>(stds[2]) : static_cast<DType>(0.0);
873+
(log(ref_width / a_width) - static_cast<DType>(means[2])) /
874+
static_cast<DType>(stds[2]) : static_cast<DType>(0.0);
872875
out_targets[a_index + 3] = valid > static_cast<DType>(0.5) ?
873-
(log(ref_height / a_height) - static_cast<DType>(means[3])) / static_cast<DType>(stds[3]) : static_cast<DType>(0.0);
876+
(log(ref_height / a_height) - static_cast<DType>(means[3])) /
877+
static_cast<DType>(stds[3]) : static_cast<DType>(0.0);
874878
}
875879
};
876880

src/operator/contrib/bounding_box.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ NNVM_REGISTER_OP(_contrib_box_encode)
216216
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<6, 2>)
217217
.set_attr<FCompute>("FCompute<cpu>", BoxEncodeForward<cpu>)
218218
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
219-
.add_argument("samples", "NDArray-or-Symbol", "(B, N) value +1 (positive), -1 (negative), 0 (ignore)")
219+
.add_argument("samples", "NDArray-or-Symbol", "(B, N) value +1 (positive), -1 (negative), "
220+
"0 (ignore)")
220221
.add_argument("matches", "NDArray-or-Symbol", "(B, N) value range [0, M)")
221222
.add_argument("anchors", "NDArray-or-Symbol", "(B, N, 4) encoded in corner")
222223
.add_argument("refs", "NDArray-or-Symbol", "(B, N, 4) encoded in corner")

src/operator/tensor/amp_cast.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,13 @@ struct AMPCastParam : public dmlc::Parameter<AMPCastParam> {
4848

4949
struct AMPMultiCastParam : public dmlc::Parameter<AMPMultiCastParam> {
5050
int num_outputs;
51+
bool cast_narrow;
5152

5253
DMLC_DECLARE_PARAMETER(AMPMultiCastParam) {
5354
DMLC_DECLARE_FIELD(num_outputs)
5455
.describe("Number of input/output pairs to be casted to the widest type.");
56+
DMLC_DECLARE_FIELD(cast_narrow).set_default(false)
57+
.describe("Whether to cast to the narrowest type");
5558
}
5659
};
5760

@@ -80,10 +83,10 @@ inline bool AMPMultiCastType(const nnvm::NodeAttrs& attrs,
8083
CHECK_EQ(in_attrs->size(), param.num_outputs);
8184
CHECK_EQ(out_attrs->size(), param.num_outputs);
8285
bool ret = true;
83-
int widest_type = kFloat16;
86+
int widest_type = param.cast_narrow ? kFloat32 : kFloat16;
8487
for (int i = 0; i < param.num_outputs; ++i) {
8588
if ((*in_attrs)[i] == kFloat32 || (*out_attrs)[i] == kFloat32) {
86-
widest_type = kFloat32;
89+
widest_type = param.cast_narrow ? kFloat16 : kFloat32;
8790
}
8891
}
8992
for (int i = 0; i < param.num_outputs; ++i) {

0 commit comments

Comments
 (0)