@@ -34,7 +34,7 @@ struct ElementWiseSumParam : public dmlc::Parameter<ElementWiseSumParam> {
34
34
}
35
35
};
36
36
37
- template <typename xpu>
37
+ template <typename xpu, typename DType >
38
38
class ElementWiseSumOp : public Operator {
39
39
public:
40
40
explicit ElementWiseSumOp (ElementWiseSumParam param)
@@ -52,34 +52,34 @@ class ElementWiseSumOp : public Operator {
52
52
if (req[elemsum::kOut ] == kNullOp ) return ;
53
53
54
54
Stream<xpu> *s = ctx.get_stream <xpu>();
55
- Tensor<xpu, 2 > out = out_data[elemsum::kOut ].FlatTo2D <xpu, real_t >(s);
55
+ Tensor<xpu, 2 , DType > out = out_data[elemsum::kOut ].FlatTo2D <xpu, DType >(s);
56
56
switch (size_) {
57
57
case 2 : {
58
- Tensor<xpu, 2 > in_0 = in_data[elemsum::kData0 ].FlatTo2D <xpu, real_t >(s);
59
- Tensor<xpu, 2 > in_1 = in_data[elemsum::kData1 ].FlatTo2D <xpu, real_t >(s);
58
+ Tensor<xpu, 2 , DType > in_0 = in_data[elemsum::kData0 ].FlatTo2D <xpu, DType >(s);
59
+ Tensor<xpu, 2 , DType > in_1 = in_data[elemsum::kData1 ].FlatTo2D <xpu, DType >(s);
60
60
Assign (out, req[elemsum::kOut ], in_0 + in_1);
61
61
break ;
62
62
}
63
63
case 3 : {
64
- Tensor<xpu, 2 > in_0 = in_data[elemsum::kData0 ].FlatTo2D <xpu, real_t >(s);
65
- Tensor<xpu, 2 > in_1 = in_data[elemsum::kData1 ].FlatTo2D <xpu, real_t >(s);
66
- Tensor<xpu, 2 > in_2 = in_data[elemsum::kData2 ].FlatTo2D <xpu, real_t >(s);
64
+ Tensor<xpu, 2 , DType > in_0 = in_data[elemsum::kData0 ].FlatTo2D <xpu, DType >(s);
65
+ Tensor<xpu, 2 , DType > in_1 = in_data[elemsum::kData1 ].FlatTo2D <xpu, DType >(s);
66
+ Tensor<xpu, 2 , DType > in_2 = in_data[elemsum::kData2 ].FlatTo2D <xpu, DType >(s);
67
67
Assign (out, req[elemsum::kOut ], in_0 + in_1 + in_2);
68
68
break ;
69
69
}
70
70
case 4 : {
71
- Tensor<xpu, 2 > in_0 = in_data[elemsum::kData0 ].FlatTo2D <xpu, real_t >(s);
72
- Tensor<xpu, 2 > in_1 = in_data[elemsum::kData1 ].FlatTo2D <xpu, real_t >(s);
73
- Tensor<xpu, 2 > in_2 = in_data[elemsum::kData2 ].FlatTo2D <xpu, real_t >(s);
74
- Tensor<xpu, 2 > in_3 = in_data[elemsum::kData3 ].FlatTo2D <xpu, real_t >(s);
71
+ Tensor<xpu, 2 , DType > in_0 = in_data[elemsum::kData0 ].FlatTo2D <xpu, DType >(s);
72
+ Tensor<xpu, 2 , DType > in_1 = in_data[elemsum::kData1 ].FlatTo2D <xpu, DType >(s);
73
+ Tensor<xpu, 2 , DType > in_2 = in_data[elemsum::kData2 ].FlatTo2D <xpu, DType >(s);
74
+ Tensor<xpu, 2 , DType > in_3 = in_data[elemsum::kData3 ].FlatTo2D <xpu, DType >(s);
75
75
Assign (out, req[elemsum::kOut ], in_0 + in_1 + in_2 + in_3);
76
76
break ;
77
77
}
78
78
default : {
79
- Tensor<xpu, 2 > in_0 = in_data[elemsum::kData0 ].FlatTo2D <xpu, real_t >(s);
79
+ Tensor<xpu, 2 , DType > in_0 = in_data[elemsum::kData0 ].FlatTo2D <xpu, DType >(s);
80
80
Assign (out, req[elemsum::kOut ], F<mshadow_op::identity>(in_0));
81
81
for (int i = 1 ; i < size_; ++i) {
82
- out += in_data[i].FlatTo2D <xpu, real_t >(s);
82
+ out += in_data[i].FlatTo2D <xpu, DType >(s);
83
83
}
84
84
break ;
85
85
}
@@ -97,10 +97,10 @@ class ElementWiseSumOp : public Operator {
97
97
using namespace mshadow ::expr;
98
98
CHECK_EQ (in_grad.size (), static_cast <size_t >(size_));
99
99
Stream<xpu> *s = ctx.get_stream <xpu>();
100
- Tensor<xpu, 2 > ograd = out_grad[elemsum::kOut ].FlatTo2D <xpu, real_t >(s);
100
+ Tensor<xpu, 2 , DType > ograd = out_grad[elemsum::kOut ].FlatTo2D <xpu, DType >(s);
101
101
for (int i = 0 ; i < size_; ++i) {
102
102
if (req[i] == kNullOp || req[i] == kWriteInplace ) continue ;
103
- Tensor<xpu, 2 > igrad = in_grad[i].FlatTo2D <xpu, real_t >(s);
103
+ Tensor<xpu, 2 , DType > igrad = in_grad[i].FlatTo2D <xpu, DType >(s);
104
104
Assign (igrad, req[i], F<mshadow_op::identity>(ograd));
105
105
}
106
106
}
@@ -120,7 +120,7 @@ class ElementWiseSumOp : public Operator {
120
120
}; // class ElementWiseSumOp
121
121
122
122
template <typename xpu>
123
- Operator* CreateOp (ElementWiseSumParam param);
123
+ Operator* CreateOp (ElementWiseSumParam param, int dtype );
124
124
125
125
#if DMLC_USE_CXX11
126
126
class ElementWiseSumProp : public OperatorProperty {
@@ -155,6 +155,36 @@ class ElementWiseSumProp : public OperatorProperty {
155
155
return true ;
156
156
}
157
157
158
+ bool InferType (std::vector<int > *in_type,
159
+ std::vector<int > *out_type,
160
+ std::vector<int > *aux_type) const override {
161
+ size_t nin = in_type->size ();
162
+ CHECK_EQ (nin, static_cast <size_t >(param_.num_args ));
163
+
164
+ int dtype = -1 ;
165
+ for (size_t i = 0 ; i < nin; ++i) {
166
+ if (dtype == -1 ) {
167
+ dtype = in_type->at (i);
168
+ } else {
169
+ CHECK (in_type->at (i) == dtype ||
170
+ in_type->at (i) == -1 ) <<
171
+ " This operator requires uniform type" ;
172
+ }
173
+ }
174
+
175
+ if (dtype == -1 ) {
176
+ LOG (FATAL) << " At least one input type needs to be known" ;
177
+ return false ;
178
+ }
179
+
180
+ in_type->clear ();
181
+ for (size_t i = 0 ; i < nin; ++i) in_type->push_back (dtype);
182
+
183
+ out_type->clear ();
184
+ out_type->push_back (dtype);
185
+ return true ;
186
+ }
187
+
158
188
std::vector<std::string> ListArguments () const override {
159
189
std::vector<std::string> ret;
160
190
for (int i = 0 ; i < param_.num_args ; ++i) {
@@ -194,7 +224,13 @@ class ElementWiseSumProp : public OperatorProperty {
194
224
return {{in_data[0 ], out_data[0 ]}};
195
225
}
196
226
197
- Operator* CreateOperator (Context ctx) const override ;
227
+ Operator* CreateOperator (Context ctx) const override {
228
+ LOG (FATAL) << " Not Implemented" ;
229
+ return NULL ;
230
+ }
231
+
232
+ Operator* CreateOperatorEx (Context ctx, std::vector<TShape> *in_shape,
233
+ std::vector<int > *in_type) const override ;
198
234
199
235
private:
200
236
ElementWiseSumParam param_;
0 commit comments