
Commit 7eec8b9

haojin2 authored and ashokei committed
[MXNET-92] Support float16 in L2Normalization operator (apache#10078)
* enable other dtype in l2 normalization
* Get rid of older code
* address code reviews: get rid of unnecessary checks
* address code reviews
* fix buggy InferType in L2Normalization
* address code review: change atol
1 parent 6f73281 commit 7eec8b9
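
For orientation, a minimal usage sketch of what this change enables from the Python API; the shape, mode, and values are illustrative and not taken from the commit:

```python
import mxnet as mx
import numpy as np

# float16 input is now supported end to end; float32/float64 behave as before.
data = mx.sym.Variable('data')
out = mx.sym.L2Normalization(data=data, mode='channel', eps=1e-10)

in_data = np.random.uniform(-1, 1, (4, 3, 5)).astype(np.float16)
exe = out.simple_bind(ctx=mx.cpu(), data=in_data.shape, type_dict={'data': np.float16})
exe.forward(is_train=False, data=in_data)
print(exe.outputs[0].dtype)  # float16: InferType now propagates the input dtype
```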

File tree: 4 files changed (+109 −68)

  src/operator/l2_normalization-inl.h
  src/operator/l2_normalization.cc
  src/operator/l2_normalization.cu
  tests/python/unittest/test_operator.py

src/operator/l2_normalization-inl.h  (83 additions, 52 deletions)

@@ -66,7 +66,7 @@ struct L2NormalizationParam : public dmlc::Parameter<L2NormalizationParam> {
  * \brief This is the implementation of l2 normalization operator.
  * \tparam xpu The device that the op will be executed on.
  */
-template<typename xpu>
+template<typename xpu, typename DType>
 class L2NormalizationOp : public Operator {
  public:
   explicit L2NormalizationOp(L2NormalizationParam p) {
@@ -89,41 +89,53 @@ class L2NormalizationOp : public Operator {
     if (param_.mode == l2_normalization::kInstance) {
       Shape<2> dshape = Shape2(orig_shape[0],
         orig_shape.ProdShape(1, orig_shape.ndim()));
-      Tensor<xpu, 2> data = in_data[l2_normalization::kData]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 2> out = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 1> norm = out_data[l2_normalization::kNorm].get<xpu, 1, real_t>(s);
+      Tensor<xpu, 2, DType> data = in_data[l2_normalization::kData]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 2, DType> out = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<xpu, 1, DType>(s);
       norm = sumall_except_dim<0>(F<mxnet::op::mshadow_op::square>(data));
-      norm = F<mxnet::op::mshadow_op::square_root>(norm + param_.eps);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
+          s, norm.size(0), norm.dptr_, norm.dptr_, DType(param_.eps));
+      });
+      norm = F<mxnet::op::mshadow_op::square_root>(norm);
       out = data / broadcast<0>(norm, out.shape_);
     } else if (param_.mode == l2_normalization::kChannel) {
       CHECK_GE(orig_shape.ndim(), 3U);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
-      Tensor<xpu, 3> data = in_data[l2_normalization::kData]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> out = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
+      Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> out = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
       Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
-      Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
-        .get_with_shape<xpu, 2, real_t>(norm_shape, s);
+      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+        .get_with_shape<xpu, 2, DType>(norm_shape, s);
       norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 1);
-      norm = F<mxnet::op::mshadow_op::square_root>(norm + param_.eps);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
+          s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
+      });
+      norm = F<mxnet::op::mshadow_op::square_root>(norm);
       out = data / broadcast_with_axis(norm, 0, orig_shape[1]);
     } else if (param_.mode == l2_normalization::kSpatial) {
       CHECK_GE(orig_shape.ndim(), 3U);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
-      Tensor<xpu, 3> data = in_data[l2_normalization::kData]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> out = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
+      Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> out = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
       Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
-      Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
-        .get_with_shape<xpu, 2, real_t>(norm_shape, s);
+      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+        .get_with_shape<xpu, 2, DType>(norm_shape, s);
       norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 2);
-      norm = F<mxnet::op::mshadow_op::square_root>(norm + param_.eps);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
+          s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
+      });
+      norm = F<mxnet::op::mshadow_op::square_root>(norm);
       out = data / broadcast_with_axis(norm, 1, dshape[2]);
     } else {
       LOG(FATAL) << "Unexpected mode in l2 normalization";
@@ -148,15 +160,15 @@ class L2NormalizationOp : public Operator {
     if (param_.mode == l2_normalization::kInstance) {
       Shape<2> dshape = Shape2(orig_shape[0],
         orig_shape.ProdShape(1, orig_shape.ndim()));
-      Tensor<xpu, 2> data = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 2> grad_in = in_grad[l2_normalization::kData]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 2> grad_out = out_grad[l2_normalization::kOut]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 1> norm = out_data[l2_normalization::kNorm].get<xpu, 1, real_t>(s);
-      Tensor<xpu, 1> temp = ctx.requested[l2_normalization::kTempSpace]
-        .get_space<xpu>(mshadow::Shape1(data.shape_[0]), s);
+      Tensor<xpu, 2, DType> data = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 2, DType> grad_in = in_grad[l2_normalization::kData]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 2, DType> grad_out = out_grad[l2_normalization::kOut]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<xpu, 1, DType>(s);
+      Tensor<xpu, 1, DType> temp = ctx.requested[l2_normalization::kTempSpace]
+        .get_space_typed<xpu, 1, DType>(mshadow::Shape1(data.shape_[0]), s);
       temp = sumall_except_dim<0>(grad_out * data);
       Assign(grad_in, req[l2_normalization::kData],
         (grad_out - data * broadcast<0>(temp, data.shape_)) /
@@ -165,17 +177,17 @@ class L2NormalizationOp : public Operator {
       CHECK_GE(orig_shape.ndim(), 3U);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
-      Tensor<xpu, 3> data = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> grad_in = in_grad[l2_normalization::kData]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> grad_out = out_grad[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
+      Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> grad_in = in_grad[l2_normalization::kData]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> grad_out = out_grad[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
       Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
-      Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
-        .get_with_shape<xpu, 2, real_t>(norm_shape, s);
-      Tensor<xpu, 2> temp = ctx.requested[l2_normalization::kTempSpace]
-        .get_space<xpu>(mshadow::Shape2(data.shape_[0], data.shape_[2]), s);
+      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+        .get_with_shape<xpu, 2, DType>(norm_shape, s);
+      Tensor<xpu, 2, DType> temp = ctx.requested[l2_normalization::kTempSpace]
+        .get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[2]), s);
       temp = reduce_with_axis<red::sum, false>(grad_out * data, 1);
       Assign(grad_in, req[l2_normalization::kData],
         (grad_out - data * broadcast_with_axis(temp, 0, orig_shape[1])) /
@@ -184,17 +196,17 @@ class L2NormalizationOp : public Operator {
       CHECK_GE(orig_shape.ndim(), 3U);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
-      Tensor<xpu, 3> data = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> grad_in = in_grad[l2_normalization::kData]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> grad_out = out_grad[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
+      Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> grad_in = in_grad[l2_normalization::kData]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> grad_out = out_grad[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
       Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
-      Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
-        .get_with_shape<xpu, 2, real_t>(norm_shape, s);
-      Tensor<xpu, 2> temp = ctx.requested[l2_normalization::kTempSpace]
-        .get_space<xpu>(mshadow::Shape2(data.shape_[0], data.shape_[1]), s);
+      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+        .get_with_shape<xpu, 2, DType>(norm_shape, s);
+      Tensor<xpu, 2, DType> temp = ctx.requested[l2_normalization::kTempSpace]
+        .get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[1]), s);
       temp = reduce_with_axis<red::sum, false>(grad_out * data, 2);
       Assign(grad_in, req[l2_normalization::kData],
         (grad_out - data * broadcast_with_axis(temp, 1, dshape[2])) /
@@ -210,7 +222,7 @@ class L2NormalizationOp : public Operator {

 // Decalre Factory function, used for dispatch specialization
 template<typename xpu>
-Operator* CreateOp(L2NormalizationParam param);
+Operator* CreateOp(L2NormalizationParam param, int dtype);

 #if DMLC_USE_CXX11
 class L2NormalizationProp : public OperatorProperty {
@@ -235,6 +247,19 @@ class L2NormalizationProp : public OperatorProperty {
     return param_.__DICT__();
   }

+  bool InferType(std::vector<int> *in_type,
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const override {
+    int dtype = (*in_type)[0];
+    type_assign(&dtype, (*out_type)[0]);
+    type_assign(&dtype, (*out_type)[1]);
+
+    TYPE_ASSIGN_CHECK(*in_type, 0, dtype);
+    TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
+    TYPE_ASSIGN_CHECK(*out_type, 1, dtype);
+    return dtype != -1;
+  }
+
   bool InferShape(std::vector<TShape> *in_shape,
                   std::vector<TShape> *out_shape,
                   std::vector<TShape> *aux_shape) const override {
@@ -294,7 +319,13 @@ class L2NormalizationProp : public OperatorProperty {
     return {ResourceRequest::kTempSpace};
   }

-  Operator* CreateOperator(Context ctx) const override;
+  Operator* CreateOperator(Context ctx) const override {
+    LOG(FATAL) << "Not Implemented.";
+    return NULL;
+  }
+
+  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                             std::vector<int> *in_type) const override;

  private:
   L2NormalizationParam param_;
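
For reference, the forward pass above computes the same quantity in all three modes; the commit only changes how eps is added (an explicit element-wise kernel instead of an mshadow expression) so the arithmetic stays in DType. In instance mode, with the reduction over all non-batch elements of sample i:

```latex
% out_{i,j} is element j of sample i divided by that sample's eps-stabilised L2 norm;
% channel mode reduces over the channel axis, spatial mode over the spatial axes.
\[
  \mathrm{out}_{i,j} = \frac{x_{i,j}}{\sqrt{\sum_{k} x_{i,k}^{2} + \varepsilon}}
\]
```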

src/operator/l2_normalization.cc  (9 additions, 4 deletions)

@@ -26,13 +26,18 @@
 namespace mxnet {
 namespace op {
 template<>
-Operator* CreateOp<cpu>(L2NormalizationParam param) {
-  return new L2NormalizationOp<cpu>(param);
+Operator* CreateOp<cpu>(L2NormalizationParam param, int dtype) {
+  Operator* op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new L2NormalizationOp<cpu, DType>(param);
+  });
+  return op;
 }
 
 // DO_BIND_DISPATCH comes from static_operator_common.h
-Operator* L2NormalizationProp::CreateOperator(Context ctx) const {
-  DO_BIND_DISPATCH(CreateOp, param_);
+Operator* L2NormalizationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                                                std::vector<int> *in_type) const {
+  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
 }
 
 DMLC_REGISTER_PARAMETER(L2NormalizationParam);

src/operator/l2_normalization.cu  (6 additions, 2 deletions)

@@ -26,8 +26,12 @@
 namespace mxnet {
 namespace op {
 template<>
-Operator* CreateOp<gpu>(L2NormalizationParam param) {
-  return new L2NormalizationOp<gpu>(param);
+Operator* CreateOp<gpu>(L2NormalizationParam param, int dtype) {
+  Operator* op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new L2NormalizationOp<gpu, DType>(param);
+  });
+  return op;
 }
 } // namespace op
 } // namespace mxnet
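
Both factories now pick the element type at runtime through MSHADOW_REAL_TYPE_SWITCH, and CreateOperatorEx hands the inferred input dtype (in_type->at(0)) to CreateOp. The effect is observable from Python through type inference; a small sketch, not part of the commit:

```python
import mxnet as mx
import numpy as np

data = mx.sym.Variable('data')
sym = mx.sym.L2Normalization(data=data, mode='instance')

# The new InferType assigns one common dtype to the input and the outputs,
# so inference succeeds for float16 as well as float32/float64.
arg_types, out_types, _ = sym.infer_type(data=np.float16)
print(out_types)  # each visible output is inferred as np.float16
```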

tests/python/unittest/test_operator.py  (11 additions, 10 deletions)

@@ -2391,11 +2391,11 @@ def test_instance_normalization():
     check_instance_norm_with_shape((3,3,2,3,2,1,1), default_context())


-def check_l2_normalization(in_shape, mode, norm_eps=1e-10):
+def check_l2_normalization(in_shape, mode, dtype, norm_eps=1e-10):
     ctx = default_context()
     data = mx.symbol.Variable('data')
     out = mx.symbol.L2Normalization(data=data, mode=mode, eps=norm_eps)
-    in_data = np.random.uniform(-1, 1, in_shape)
+    in_data = np.random.uniform(-1, 1, in_shape).astype(dtype)
     # calculate numpy results
     if mode == 'channel':
         assert in_data.ndim > 2
@@ -2419,21 +2419,22 @@ def check_l2_normalization(in_shape, mode, norm_eps=1e-10):
     exe = out.simple_bind(ctx=ctx, data=in_data.shape)
     output = exe.forward(is_train=True, data=in_data)
     # compare numpy + mxnet
-    assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-5)
+    assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-2 if dtype is 'float16' else 1e-5, atol=1e-5)
     # check gradient
     check_numeric_gradient(out, [in_data], numeric_eps=1e-3, rtol=1e-2, atol=1e-3)


 # TODO(szha): Seeding this masks failures. We need to do a deep dive for failures without this seed.
 @with_seed(1234)
 def test_l2_normalization():
-    for mode in ['channel', 'spatial', 'instance']:
-        for nbatch in [1, 4]:
-            for nchannel in [3, 5]:
-                for height in [4, 6]:
-                    check_l2_normalization((nbatch, nchannel, height), mode)
-                    for width in [5, 7]:
-                        check_l2_normalization((nbatch, nchannel, height, width), mode)
+    for dtype in ['float16', 'float32', 'float64']:
+        for mode in ['channel', 'spatial', 'instance']:
+            for nbatch in [1, 4]:
+                for nchannel in [3, 5]:
+                    for height in [4, 6]:
+                        check_l2_normalization((nbatch, nchannel, height), mode, dtype)
+                        for width in [5, 7]:
+                            check_l2_normalization((nbatch, nchannel, height, width), mode, dtype)


 def check_layer_normalization(in_shape, axis, eps, dtype=np.float32, forward_check_eps=1E-3):
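
The numpy reference inside check_l2_normalization lies outside this diff's context. As a reminder, here is a sketch of an 'instance'-mode reference that mirrors the operator's forward pass; the helper name np_l2norm_instance is illustrative, not from the test:

```python
import numpy as np

def np_l2norm_instance(x, eps=1e-10):
    # Flatten each sample, add eps to the summed squares before the square
    # root (as the operator does), then divide the sample by its norm.
    flat = x.reshape(x.shape[0], -1)
    norm = np.sqrt((flat * flat).sum(axis=1) + eps)
    return x / norm.reshape((-1,) + (1,) * (x.ndim - 1))
```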
