@@ -66,7 +66,7 @@ struct L2NormalizationParam : public dmlc::Parameter<L2NormalizationParam> {
  * \brief This is the implementation of l2 normalization operator.
  * \tparam xpu The device that the op will be executed on.
  */
-template<typename xpu>
+template<typename xpu, typename DType>
 class L2NormalizationOp : public Operator {
  public:
   explicit L2NormalizationOp(L2NormalizationParam p) {
@@ -89,41 +89,53 @@ class L2NormalizationOp : public Operator {
     if (param_.mode == l2_normalization::kInstance) {
       Shape<2> dshape = Shape2(orig_shape[0],
         orig_shape.ProdShape(1, orig_shape.ndim()));
-      Tensor<xpu, 2> data = in_data[l2_normalization::kData]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 2> out = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 1> norm = out_data[l2_normalization::kNorm].get<xpu, 1, real_t>(s);
+      Tensor<xpu, 2, DType> data = in_data[l2_normalization::kData]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 2, DType> out = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<xpu, 1, DType>(s);
       norm = sumall_except_dim<0>(F<mxnet::op::mshadow_op::square>(data));
-      norm = F<mxnet::op::mshadow_op::square_root>(norm + param_.eps);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
+          s, norm.size(0), norm.dptr_, norm.dptr_, DType(param_.eps));
+      });
+      norm = F<mxnet::op::mshadow_op::square_root>(norm);
       out = data / broadcast<0>(norm, out.shape_);
     } else if (param_.mode == l2_normalization::kChannel) {
       CHECK_GE(orig_shape.ndim(), 3U);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
-      Tensor<xpu, 3> data = in_data[l2_normalization::kData]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> out = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
+      Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> out = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
       Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
-      Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
-        .get_with_shape<xpu, 2, real_t>(norm_shape, s);
+      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+        .get_with_shape<xpu, 2, DType>(norm_shape, s);
       norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 1);
-      norm = F<mxnet::op::mshadow_op::square_root>(norm + param_.eps);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
+          s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
+      });
+      norm = F<mxnet::op::mshadow_op::square_root>(norm);
       out = data / broadcast_with_axis(norm, 0, orig_shape[1]);
     } else if (param_.mode == l2_normalization::kSpatial) {
       CHECK_GE(orig_shape.ndim(), 3U);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
-      Tensor<xpu, 3> data = in_data[l2_normalization::kData]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> out = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
+      Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> out = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
       Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
-      Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
-        .get_with_shape<xpu, 2, real_t>(norm_shape, s);
+      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+        .get_with_shape<xpu, 2, DType>(norm_shape, s);
       norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 2);
-      norm = F<mxnet::op::mshadow_op::square_root>(norm + param_.eps);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
+          s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
+      });
+      norm = F<mxnet::op::mshadow_op::square_root>(norm);
       out = data / broadcast_with_axis(norm, 1, dshape[2]);
     } else {
       LOG(FATAL) << "Unexpected mode in l2 normalization";
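
For reference, all three modes above compute out = data / sqrt(sum(data^2) + eps); only the axes being reduced differ. A minimal scalar sketch of the instance-mode semantics (plain C++ for illustration only; the helper name and flat row-major layout are mine, not part of the patch):

```cpp
#include <cmath>
#include <vector>

// Illustrative only: row-wise L2 normalization of a row-major [n, d] matrix.
// Mirrors the instance-mode math: norm_i = sqrt(sum_j data[i][j]^2 + eps),
// out[i][j] = data[i][j] / norm_i.  Caller pre-sizes out (n*d) and norm (n).
void l2_normalize_instance(const std::vector<float>& data,
                           std::vector<float>* out,
                           std::vector<float>* norm,
                           int n, int d, float eps) {
  for (int i = 0; i < n; ++i) {
    float sq = 0.0f;
    for (int j = 0; j < d; ++j) sq += data[i * d + j] * data[i * d + j];
    (*norm)[i] = std::sqrt(sq + eps);  // eps added to the squared sum, then sqrt
    for (int j = 0; j < d; ++j) (*out)[i * d + j] = data[i * d + j] / (*norm)[i];
  }
}
```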
@@ -148,15 +160,15 @@ class L2NormalizationOp : public Operator {
     if (param_.mode == l2_normalization::kInstance) {
       Shape<2> dshape = Shape2(orig_shape[0],
         orig_shape.ProdShape(1, orig_shape.ndim()));
-      Tensor<xpu, 2> data = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 2> grad_in = in_grad[l2_normalization::kData]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 2> grad_out = out_grad[l2_normalization::kOut]
-        .get_with_shape<xpu, 2, real_t>(dshape, s);
-      Tensor<xpu, 1> norm = out_data[l2_normalization::kNorm].get<xpu, 1, real_t>(s);
-      Tensor<xpu, 1> temp = ctx.requested[l2_normalization::kTempSpace]
-        .get_space<xpu>(mshadow::Shape1(data.shape_[0]), s);
+      Tensor<xpu, 2, DType> data = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 2, DType> grad_in = in_grad[l2_normalization::kData]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 2, DType> grad_out = out_grad[l2_normalization::kOut]
+        .get_with_shape<xpu, 2, DType>(dshape, s);
+      Tensor<xpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<xpu, 1, DType>(s);
+      Tensor<xpu, 1, DType> temp = ctx.requested[l2_normalization::kTempSpace]
+        .get_space_typed<xpu, 1, DType>(mshadow::Shape1(data.shape_[0]), s);
       temp = sumall_except_dim<0>(grad_out * data);
       Assign(grad_in, req[l2_normalization::kData],
         (grad_out - data * broadcast<0>(temp, data.shape_)) /
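
The backward expression above follows from differentiating y = x / n with n = sqrt(sum x^2 + eps): the operator reuses the stored output y and norm n, so the input gradient is (g - y * <g, y>) / n. Written out for the instance mode (notation mine, not from the source):

```latex
% Instance-mode backward of L2 normalization:
%   y = x / n,  n = sqrt(sum_j x_j^2 + eps),  g = upstream gradient
\[
\frac{\partial L}{\partial x_j}
  = \frac{g_j - y_j \sum_i g_i y_i}{n},
\qquad
y = \frac{x}{n},\quad
n = \sqrt{\textstyle\sum_j x_j^2 + \varepsilon}.
\]
```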
@@ -165,17 +177,17 @@ class L2NormalizationOp : public Operator {
       CHECK_GE(orig_shape.ndim(), 3U);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
-      Tensor<xpu, 3> data = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> grad_in = in_grad[l2_normalization::kData]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> grad_out = out_grad[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
+      Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> grad_in = in_grad[l2_normalization::kData]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> grad_out = out_grad[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
       Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
-      Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
-        .get_with_shape<xpu, 2, real_t>(norm_shape, s);
-      Tensor<xpu, 2> temp = ctx.requested[l2_normalization::kTempSpace]
-        .get_space<xpu>(mshadow::Shape2(data.shape_[0], data.shape_[2]), s);
+      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+        .get_with_shape<xpu, 2, DType>(norm_shape, s);
+      Tensor<xpu, 2, DType> temp = ctx.requested[l2_normalization::kTempSpace]
+        .get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[2]), s);
       temp = reduce_with_axis<red::sum, false>(grad_out * data, 1);
       Assign(grad_in, req[l2_normalization::kData],
         (grad_out - data * broadcast_with_axis(temp, 0, orig_shape[1])) /
@@ -184,17 +196,17 @@ class L2NormalizationOp : public Operator {
       CHECK_GE(orig_shape.ndim(), 3U);
       Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
         orig_shape.ProdShape(2, orig_shape.ndim()));
-      Tensor<xpu, 3> data = out_data[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> grad_in = in_grad[l2_normalization::kData]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
-      Tensor<xpu, 3> grad_out = out_grad[l2_normalization::kOut]
-        .get_with_shape<xpu, 3, real_t>(dshape, s);
+      Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> grad_in = in_grad[l2_normalization::kData]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
+      Tensor<xpu, 3, DType> grad_out = out_grad[l2_normalization::kOut]
+        .get_with_shape<xpu, 3, DType>(dshape, s);
       Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
-      Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
-        .get_with_shape<xpu, 2, real_t>(norm_shape, s);
-      Tensor<xpu, 2> temp = ctx.requested[l2_normalization::kTempSpace]
-        .get_space<xpu>(mshadow::Shape2(data.shape_[0], data.shape_[1]), s);
+      Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+        .get_with_shape<xpu, 2, DType>(norm_shape, s);
+      Tensor<xpu, 2, DType> temp = ctx.requested[l2_normalization::kTempSpace]
+        .get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[1]), s);
       temp = reduce_with_axis<red::sum, false>(grad_out * data, 2);
       Assign(grad_in, req[l2_normalization::kData],
         (grad_out - data * broadcast_with_axis(temp, 1, dshape[2])) /
@@ -210,7 +222,7 @@ class L2NormalizationOp : public Operator {
 
 // Decalre Factory function, used for dispatch specialization
 template<typename xpu>
-Operator* CreateOp(L2NormalizationParam param);
+Operator* CreateOp(L2NormalizationParam param, int dtype);
 
 #if DMLC_USE_CXX11
 class L2NormalizationProp : public OperatorProperty {
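
The matching change to the .cc/.cu factory is not part of this hunk; in the usual MXNet pattern the new `int dtype` argument is dispatched with MSHADOW_REAL_TYPE_SWITCH so the operator is instantiated with the matching DType. A sketch under that assumption:

```cpp
// Assumed sketch of the CPU factory in l2_normalization.cc (not shown in this diff).
template<>
Operator* CreateOp<cpu>(L2NormalizationParam param, int dtype) {
  Operator* op = nullptr;
  // MSHADOW_REAL_TYPE_SWITCH expands the body once for the requested floating type,
  // binding DType to float, double, or mshadow::half::half_t.
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    op = new L2NormalizationOp<cpu, DType>(param);
  });
  return op;
}
```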
@@ -235,6 +247,19 @@ class L2NormalizationProp : public OperatorProperty {
     return param_.__DICT__();
   }
 
+  bool InferType(std::vector<int> *in_type,
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const override {
+    int dtype = (*in_type)[0];
+    type_assign(&dtype, (*out_type)[0]);
+    type_assign(&dtype, (*out_type)[1]);
+
+    TYPE_ASSIGN_CHECK(*in_type, 0, dtype);
+    TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
+    TYPE_ASSIGN_CHECK(*out_type, 1, dtype);
+    return dtype != -1;
+  }
+
   bool InferShape(std::vector<TShape> *in_shape,
                   std::vector<TShape> *out_shape,
                   std::vector<TShape> *aux_shape) const override {
@@ -294,7 +319,13 @@ class L2NormalizationProp : public OperatorProperty {
     return {ResourceRequest::kTempSpace};
   }
 
-  Operator* CreateOperator(Context ctx) const override;
+  Operator* CreateOperator(Context ctx) const override {
+    LOG(FATAL) << "Not Implemented.";
+    return NULL;
+  }
+
+  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                             std::vector<int> *in_type) const override;
 
  private:
   L2NormalizationParam param_;
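
CreateOperatorEx is only declared here; in the common MXNet pattern it is defined in the .cc file and forwards the inferred input type to the dtype-aware CreateOp. A plausible sketch (assumed, not shown in this diff):

```cpp
// Assumed sketch: forward the inferred dtype to the dtype-aware factory.
Operator* L2NormalizationProp::CreateOperatorEx(Context ctx,
                                                std::vector<TShape>* in_shape,
                                                std::vector<int>* in_type) const {
  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
}
```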