
Commit 7b847e5: enable other DTypes in ElementWiseSum
Parent: 4605518

4 files changed, +88 / -23 lines:
  src/operator/elementwise_sum-inl.h
  src/operator/elementwise_sum.cc
  src/operator/elementwise_sum.cu
  tests/python/gpu/test_operator_gpu.py
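The core of the change: ElementWiseSumOp gains a DType template parameter alongside the device type, so the operator can be instantiated once per element type instead of being hard-wired to real_t (float32). A minimal sketch of what this makes expressible, assuming the surrounding mxnet/mshadow headers; the variable names are illustrative only:

ElementWiseSumParam param;
param.num_args = 2;
// Before this commit only ElementWiseSumOp<cpu> over real_t (float) existed;
// now one concrete operator per mshadow real type can be created.
Operator *op_f64 = new ElementWiseSumOp<cpu, double>(param);
Operator *op_f16 = new ElementWiseSumOp<cpu, mshadow::half::half_t>(param);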

src/operator/elementwise_sum-inl.h

Lines changed: 53 additions & 17 deletions
@@ -34,7 +34,7 @@ struct ElementWiseSumParam : public dmlc::Parameter<ElementWiseSumParam> {
   }
 };
 
-template<typename xpu>
+template<typename xpu, typename DType>
 class ElementWiseSumOp : public Operator {
  public:
   explicit ElementWiseSumOp(ElementWiseSumParam param)
@@ -52,34 +52,34 @@ class ElementWiseSumOp : public Operator {
     if (req[elemsum::kOut] == kNullOp) return;
 
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 2> out = out_data[elemsum::kOut].FlatTo2D<xpu, real_t>(s);
+    Tensor<xpu, 2, DType> out = out_data[elemsum::kOut].FlatTo2D<xpu, DType>(s);
     switch (size_) {
       case 2: {
-        Tensor<xpu, 2> in_0 = in_data[elemsum::kData0].FlatTo2D<xpu, real_t>(s);
-        Tensor<xpu, 2> in_1 = in_data[elemsum::kData1].FlatTo2D<xpu, real_t>(s);
+        Tensor<xpu, 2, DType> in_0 = in_data[elemsum::kData0].FlatTo2D<xpu, DType>(s);
+        Tensor<xpu, 2, DType> in_1 = in_data[elemsum::kData1].FlatTo2D<xpu, DType>(s);
         Assign(out, req[elemsum::kOut], in_0 + in_1);
         break;
       }
       case 3: {
-        Tensor<xpu, 2> in_0 = in_data[elemsum::kData0].FlatTo2D<xpu, real_t>(s);
-        Tensor<xpu, 2> in_1 = in_data[elemsum::kData1].FlatTo2D<xpu, real_t>(s);
-        Tensor<xpu, 2> in_2 = in_data[elemsum::kData2].FlatTo2D<xpu, real_t>(s);
+        Tensor<xpu, 2, DType> in_0 = in_data[elemsum::kData0].FlatTo2D<xpu, DType>(s);
+        Tensor<xpu, 2, DType> in_1 = in_data[elemsum::kData1].FlatTo2D<xpu, DType>(s);
+        Tensor<xpu, 2, DType> in_2 = in_data[elemsum::kData2].FlatTo2D<xpu, DType>(s);
         Assign(out, req[elemsum::kOut], in_0 + in_1 + in_2);
         break;
       }
       case 4: {
-        Tensor<xpu, 2> in_0 = in_data[elemsum::kData0].FlatTo2D<xpu, real_t>(s);
-        Tensor<xpu, 2> in_1 = in_data[elemsum::kData1].FlatTo2D<xpu, real_t>(s);
-        Tensor<xpu, 2> in_2 = in_data[elemsum::kData2].FlatTo2D<xpu, real_t>(s);
-        Tensor<xpu, 2> in_3 = in_data[elemsum::kData3].FlatTo2D<xpu, real_t>(s);
+        Tensor<xpu, 2, DType> in_0 = in_data[elemsum::kData0].FlatTo2D<xpu, DType>(s);
+        Tensor<xpu, 2, DType> in_1 = in_data[elemsum::kData1].FlatTo2D<xpu, DType>(s);
+        Tensor<xpu, 2, DType> in_2 = in_data[elemsum::kData2].FlatTo2D<xpu, DType>(s);
+        Tensor<xpu, 2, DType> in_3 = in_data[elemsum::kData3].FlatTo2D<xpu, DType>(s);
         Assign(out, req[elemsum::kOut], in_0 + in_1 + in_2 + in_3);
         break;
       }
       default: {
-        Tensor<xpu, 2> in_0 = in_data[elemsum::kData0].FlatTo2D<xpu, real_t>(s);
+        Tensor<xpu, 2, DType> in_0 = in_data[elemsum::kData0].FlatTo2D<xpu, DType>(s);
         Assign(out, req[elemsum::kOut], F<mshadow_op::identity>(in_0));
         for (int i = 1; i < size_; ++i) {
-          out += in_data[i].FlatTo2D<xpu, real_t>(s);
+          out += in_data[i].FlatTo2D<xpu, DType>(s);
         }
         break;
       }
@@ -97,10 +97,10 @@ class ElementWiseSumOp : public Operator {
     using namespace mshadow::expr;
     CHECK_EQ(in_grad.size(), static_cast<size_t>(size_));
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 2> ograd = out_grad[elemsum::kOut].FlatTo2D<xpu, real_t>(s);
+    Tensor<xpu, 2, DType> ograd = out_grad[elemsum::kOut].FlatTo2D<xpu, DType>(s);
     for (int i = 0; i < size_; ++i) {
       if (req[i] == kNullOp || req[i] == kWriteInplace) continue;
-      Tensor<xpu, 2> igrad = in_grad[i].FlatTo2D<xpu, real_t>(s);
+      Tensor<xpu, 2, DType> igrad = in_grad[i].FlatTo2D<xpu, DType>(s);
       Assign(igrad, req[i], F<mshadow_op::identity>(ograd));
     }
   }
@@ -120,7 +120,7 @@ class ElementWiseSumOp : public Operator {
 };  // class ElementWiseSumOp
 
 template<typename xpu>
-Operator* CreateOp(ElementWiseSumParam param);
+Operator* CreateOp(ElementWiseSumParam param, int dtype);
 
 #if DMLC_USE_CXX11
 class ElementWiseSumProp : public OperatorProperty {
@@ -155,6 +155,36 @@ class ElementWiseSumProp : public OperatorProperty {
     return true;
   }
 
+  bool InferType(std::vector<int> *in_type,
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const override {
+    size_t nin = in_type->size();
+    CHECK_EQ(nin, static_cast<size_t>(param_.num_args));
+
+    int dtype = -1;
+    for (size_t i = 0; i < nin; ++i) {
+      if (dtype == -1) {
+        dtype = in_type->at(i);
+      } else {
+        CHECK(in_type->at(i) == dtype ||
+              in_type->at(i) == -1) <<
+              "This operator requires uniform type";
+      }
+    }
+
+    if (dtype == -1) {
+      LOG(FATAL) << "At least one input type needs to be known";
+      return false;
+    }
+
+    in_type->clear();
+    for (size_t i = 0; i < nin; ++i) in_type->push_back(dtype);
+
+    out_type->clear();
+    out_type->push_back(dtype);
+    return true;
+  }
+
   std::vector<std::string> ListArguments() const override {
     std::vector<std::string> ret;
     for (int i = 0; i < param_.num_args; ++i) {
@@ -194,7 +224,13 @@ class ElementWiseSumProp : public OperatorProperty {
     return {{in_data[0], out_data[0]}};
  }
 
-  Operator* CreateOperator(Context ctx) const override;
+  Operator* CreateOperator(Context ctx) const override {
+    LOG(FATAL) << "Not Implemented";
+    return NULL;
+  }
+
+  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                             std::vector<int> *in_type) const override;
 
  private:
  ElementWiseSumParam param_;
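The new InferType hook follows MXNet's type-inference contract: -1 marks an unknown dtype, every known input dtype must agree ("This operator requires uniform type"), and the single known dtype is propagated to all inputs and to the output. An illustrative call, assuming an ElementWiseSumProp named prop configured with num_args=2 (hypothetical variable, sketch only):

std::vector<int> in_type = {mshadow::kFloat64, -1};  // second input still unknown
std::vector<int> out_type, aux_type;
CHECK(prop.InferType(&in_type, &out_type, &aux_type));
// in_type  is now {mshadow::kFloat64, mshadow::kFloat64}
// out_type is now {mshadow::kFloat64}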

src/operator/elementwise_sum.cc

Lines changed: 13 additions & 4 deletions
@@ -7,13 +7,22 @@
 namespace mxnet {
 namespace op {
 template<>
-Operator* CreateOp<cpu>(ElementWiseSumParam param) {
-  return new ElementWiseSumOp<cpu>(param);
+Operator* CreateOp<cpu>(ElementWiseSumParam param, int dtype) {
+  Operator *op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new ElementWiseSumOp<cpu, DType>(param);
+  });
+  return op;
 }
 
 // DO_BIND_DISPATCH comes from static_operator_common.h
-Operator* ElementWiseSumProp::CreateOperator(Context ctx) const {
-  DO_BIND_DISPATCH(CreateOp, param_);
+Operator* ElementWiseSumProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                                               std::vector<int> *in_type) const {
+  std::vector<TShape> out_shape, aux_shape;
+  std::vector<int> out_type, aux_type;
+  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
+  CHECK(InferType(in_type, &out_type, &aux_type));
+  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
 }
 
 DMLC_REGISTER_PARAMETER(ElementWiseSumParam);
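MSHADOW_REAL_TYPE_SWITCH is mshadow's runtime-to-compile-time dispatch macro: it switches on the integer dtype flag and binds DType to the corresponding C++ type for the code in the braces. Roughly, the CPU CreateOp above behaves like the following hand-written switch (an illustrative sketch, not the literal macro expansion):

switch (dtype) {
  case mshadow::kFloat32: {
    typedef float DType;  // DType bound to float
    op = new ElementWiseSumOp<cpu, DType>(param);
    break;
  }
  case mshadow::kFloat64: {
    typedef double DType;  // DType bound to double
    op = new ElementWiseSumOp<cpu, DType>(param);
    break;
  }
  case mshadow::kFloat16: {
    typedef mshadow::half::half_t DType;  // DType bound to half precision
    op = new ElementWiseSumOp<cpu, DType>(param);
    break;
  }
  default:
    LOG(FATAL) << "Unknown type flag " << dtype;
}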

src/operator/elementwise_sum.cu

Lines changed: 6 additions & 2 deletions
@@ -7,8 +7,12 @@
 namespace mxnet {
 namespace op {
 template<>
-Operator* CreateOp<gpu>(ElementWiseSumParam param) {
-  return new ElementWiseSumOp<gpu>(param);
+Operator* CreateOp<gpu>(ElementWiseSumParam param, int dtype) {
+  Operator *op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new ElementWiseSumOp<gpu, DType>(param);
+  });
+  return op;
 }
 }  // namespace op
 }  // namespace mxnet

tests/python/gpu/test_operator_gpu.py

Lines changed: 16 additions & 0 deletions
@@ -105,6 +105,21 @@ def test_concat_with_type():
                 'type_dict': {'concat_arg0': np.float32, 'concat_arg1': np.float32}}]
     check_consistency(sym, ctx_list)
 
+def test_elementwisesum_with_type():
+    sym = mx.sym.ElementWiseSum(name='ews', num_args=2)
+    ctx_list = [{'ctx': mx.gpu(0), 'ews_arg1': (2, 10), 'ews_arg0': (2, 10),
+                 'type_dict': {'ews_arg0': np.float64, 'ews_arg1': np.float64}},
+                {'ctx': mx.gpu(0), 'ews_arg1': (2, 10), 'ews_arg0': (2, 10),
+                 'type_dict': {'ews_arg0': np.float32, 'ews_arg1': np.float32}},
+                {'ctx': mx.gpu(0), 'ews_arg1': (2, 10), 'ews_arg0': (2, 10),
+                 'type_dict': {'ews_arg0': np.float16, 'ews_arg1': np.float16}},
+                {'ctx': mx.cpu(0), 'ews_arg1': (2, 10), 'ews_arg0': (2, 10),
+                 'type_dict': {'ews_arg0': np.float64, 'ews_arg1': np.float64}},
+                {'ctx': mx.cpu(0), 'ews_arg1': (2, 10), 'ews_arg0': (2, 10),
+                 'type_dict': {'ews_arg0': np.float32, 'ews_arg1': np.float32}}]
+    check_consistency(sym, ctx_list)
+
+
 def test_reshape_with_type():
     sym = mx.sym.Reshape(name='reshape', shape=(-1,1,1,0))
     ctx_list = [{'ctx': mx.gpu(0), 'reshape_data': (2, 2, 2, 10), 'type_dict': {'reshape_data': np.float64}},
@@ -156,6 +171,7 @@ def test_activation_with_type():
     test_deconvolution_with_type()
     test_upsampling_with_type()
     test_concat_with_type()
+    test_elementwisesum_with_type()
     test_reshape_with_type()
     test_blockgrad_with_type()
     test_swapaxis_with_type()
