Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 4605518

Browse files
committed
enable other DTypes in SwapAxis
1 parent 22ca494 commit 4605518

File tree

4 files changed

+51
-11
lines changed

4 files changed

+51
-11
lines changed

src/operator/swapaxis-inl.h

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ struct SwapAxisParam : public dmlc::Parameter<SwapAxisParam> {
4040
};
4141

4242

43-
template<typename xpu>
43+
template<typename xpu, typename DType>
4444
class SwapAxisOp : public Operator {
4545
public:
4646
explicit SwapAxisOp(SwapAxisParam p) {
@@ -99,12 +99,12 @@ class SwapAxisOp : public Operator {
9999

100100
Reshape2Five(&inter_shape, shape_in, dim1, dim2);
101101

102-
Tensor<xpu, 5> inter_data_in = data_in.get_with_shape<xpu, 5, real_t>(inter_shape, s);
102+
Tensor<xpu, 5, DType> inter_data_in = data_in.get_with_shape<xpu, 5, DType>(inter_shape, s);
103103

104104
Shape<5> inter_shape2 = inter_shape;
105105
std::swap(inter_shape2[1], inter_shape2[3]);
106106

107-
Tensor<xpu, 5> inter_data_out = data_out.get_with_shape<xpu, 5, real_t>(inter_shape2, s);
107+
Tensor<xpu, 5, DType> inter_data_out = data_out.get_with_shape<xpu, 5, DType>(inter_shape2, s);
108108

109109
inter_data_out = swapaxis<3, 1>(inter_data_in);
110110
}
@@ -138,7 +138,7 @@ class SwapAxisOp : public Operator {
138138

139139

140140
template<typename xpu>
141-
Operator* CreateOp(SwapAxisParam param);
141+
Operator* CreateOp(SwapAxisParam param, int dtype);
142142

143143

144144
#if DMLC_USE_CXX11
@@ -171,6 +171,17 @@ class SwapAxisProp : public OperatorProperty {
171171
return true;
172172
}
173173

174+
bool InferType(std::vector<int> *in_type,
175+
std::vector<int> *out_type,
176+
std::vector<int> *aux_type) const override {
177+
CHECK_EQ(in_type->size(), 1);
178+
int dtype = (*in_type)[0];
179+
CHECK_NE(dtype, -1) << "Input must have specified type";
180+
out_type->clear();
181+
out_type->push_back(dtype);
182+
return true;
183+
}
184+
174185
OperatorProperty* Copy() const override {
175186
auto ptr = new SwapAxisProp();
176187
ptr->param_ = param_;
@@ -188,7 +199,13 @@ class SwapAxisProp : public OperatorProperty {
188199
return {out_grad[swapaxisenum::kOut]};
189200
};
190201

191-
Operator* CreateOperator(Context ctx) const override;
202+
Operator* CreateOperator(Context ctx) const override {
203+
LOG(FATAL) << "Not Implemented";
204+
return NULL;
205+
}
206+
207+
Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
208+
std::vector<int> *in_type) const override;
192209

193210
private:
194211
SwapAxisParam param_;

src/operator/swapaxis.cc

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,21 @@ namespace mxnet {
1111
namespace op {
1212

1313
template<>
14-
Operator* CreateOp<cpu>(SwapAxisParam param) {
15-
return new SwapAxisOp<cpu>(param);
14+
Operator* CreateOp<cpu>(SwapAxisParam param, int dtype) {
15+
Operator *op = NULL;
16+
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
17+
op = new SwapAxisOp<cpu, DType>(param);
18+
});
19+
return op;
1620
}
1721

18-
Operator* SwapAxisProp::CreateOperator(Context ctx) const {
19-
DO_BIND_DISPATCH(CreateOp, param_);
22+
Operator* SwapAxisProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
23+
std::vector<int> *in_type) const {
24+
std::vector<TShape> out_shape, aux_shape;
25+
std::vector<int> out_type, aux_type;
26+
CHECK(InferShape(in_shape, &out_shape, &aux_shape));
27+
CHECK(InferType(in_type, &out_type, &aux_type));
28+
DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
2029
}
2130

2231

src/operator/swapaxis.cu

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@ namespace mxnet {
1111
namespace op {
1212

1313
template<>
14-
Operator *CreateOp<gpu>(SwapAxisParam param) {
15-
return new SwapAxisOp<gpu>(param);
14+
Operator *CreateOp<gpu>(SwapAxisParam param, int dtype) {
15+
Operator *op = NULL;
16+
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
17+
op = new SwapAxisOp<gpu, DType>(param);
18+
});
19+
return op;
1620
}
1721

1822
} // namespace op

tests/python/gpu/test_operator_gpu.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,15 @@ def test_blockgrad_with_type():
123123
{'ctx': mx.cpu(0), 'bg_data': (2, 2, 2, 10), 'type_dict': {'bg_data': np.float32}}]
124124
check_consistency(sym, ctx_list)
125125

126+
def test_swapaxis_with_type():
127+
sym = mx.sym.SwapAxis(name='swap', dim1=1)
128+
ctx_list = [{'ctx': mx.gpu(0), 'swap_data': (2, 2, 2, 10), 'type_dict': {'swap_data': np.float64}},
129+
{'ctx': mx.gpu(0), 'swap_data': (2, 2, 2, 10), 'type_dict': {'swap_data': np.float32}},
130+
{'ctx': mx.gpu(0), 'swap_data': (2, 2, 2, 10), 'type_dict': {'swap_data': np.float16}},
131+
{'ctx': mx.cpu(0), 'swap_data': (2, 2, 2, 10), 'type_dict': {'swap_data': np.float64}},
132+
{'ctx': mx.cpu(0), 'swap_data': (2, 2, 2, 10), 'type_dict': {'swap_data': np.float32}}]
133+
check_consistency(sym, ctx_list)
134+
126135
def test_fullyconnected_with_type():
127136
sym = mx.sym.FullyConnected(num_hidden=3, name='inner')
128137
ctx_list = [{'ctx': mx.gpu(0), 'inner_data': (2, 10), 'type_dict': {'inner_data': np.float64}},
@@ -149,6 +158,7 @@ def test_activation_with_type():
149158
test_concat_with_type()
150159
test_reshape_with_type()
151160
test_blockgrad_with_type()
161+
test_swapaxis_with_type()
152162
test_fullyconnected_with_type()
153163
test_activation_with_type()
154164
#test_softmax_with_shape((3,4), mx.gpu())

0 commit comments

Comments
 (0)