
Commit 9a90c84

enable other DTypes in BlockGrad
1 parent 04bb0cb

4 files changed, 53 insertions(+), 12 deletions(-)

src/operator/block_grad-inl.h (+24, -6)
@@ -24,7 +24,7 @@ enum BlockGradientOpInputs {kData};
 enum BlockGradientOpOutputs {kOut};
 }  // namespace blockgrad
 
-template<typename xpu>
+template<typename xpu, typename DType>
 class BlockGradientOp : public Operator {
  public:
   virtual void Forward(const OpContext &ctx,
@@ -37,8 +37,8 @@ class BlockGradientOp : public Operator {
     CHECK_EQ(in_data.size(), 1);
     CHECK_EQ(out_data.size(), 1);
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 2> data = in_data[blockgrad::kData].FlatTo2D<xpu, real_t>(s);
-    Tensor<xpu, 2> out = out_data[blockgrad::kOut].FlatTo2D<xpu, real_t>(s);
+    Tensor<xpu, 2, DType> data = in_data[blockgrad::kData].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = out_data[blockgrad::kOut].FlatTo2D<xpu, DType>(s);
     out = F<mshadow_op::identity>(data);
   }
 
@@ -52,13 +52,13 @@ class BlockGradientOp : public Operator {
     using namespace mshadow;
     using namespace mshadow::expr;
     Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 2> grad = in_grad[blockgrad::kData].FlatTo2D<xpu, real_t>(s);
+    Tensor<xpu, 2, DType> grad = in_grad[blockgrad::kData].FlatTo2D<xpu, DType>(s);
     grad = 0.f;
   }
 };  // class BlockGradientOp
 
 template<typename xpu>
-Operator *CreateOp();
+Operator *CreateOp(int dtype);
 
 #if DMLC_USE_CXX11
 class BlockGradientProp : public OperatorProperty {
@@ -81,6 +81,18 @@ class BlockGradientProp : public OperatorProperty {
     return true;
   }
 
+  bool InferType(std::vector<int> *in_type,
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const override {
+
+    CHECK_EQ(in_type->size(), 1);
+    int dtype = (*in_type)[0];
+    CHECK_NE(dtype, -1) << "Input must have specified type";
+    out_type->clear();
+    out_type->push_back(dtype);
+    return true;
+  }
+
   OperatorProperty* Copy() const override {
     return new BlockGradientProp();
   }
@@ -102,7 +114,13 @@ class BlockGradientProp : public OperatorProperty {
     return {{in_data[blockgrad::kData], out_data[blockgrad::kOut]}};
   }
 
-  Operator* CreateOperator(Context ctx) const override;
+  Operator* CreateOperator(Context ctx) const override {
+    LOG(FATAL) << "Not Implemented";
+    return NULL;
+  }
+
+  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                             std::vector<int> *in_type) const override;
 };  // class BlockGradientProperty
 
 #endif  // DMLC_USE_CXX11

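What the header change amounts to: BlockGradientOp is now templated on DType as well as the device type, so the identity forward pass and the zeroed backward pass instantiate for float64, float32, and float16 alike. Below is a minimal stand-alone sketch of those two kernels, using std::vector in place of mshadow tensors; MiniBlockGrad and everything else in it is illustrative, not mxnet API.

// mini_blockgrad.cc -- stand-alone sketch, not mxnet code.
// Mirrors BlockGradientOp<xpu, DType>: Forward copies the input through
// unchanged, Backward overwrites the input gradient with zeros.
#include <cassert>
#include <vector>

template <typename DType>
struct MiniBlockGrad {
  // Forward: out = F<mshadow_op::identity>(data)
  static std::vector<DType> Forward(const std::vector<DType>& data) {
    return data;  // identity
  }
  // Backward: grad = 0 -- the gradient is "blocked" here
  static void Backward(std::vector<DType>* grad) {
    grad->assign(grad->size(), DType(0));
  }
};

int main() {
  std::vector<double> x = {1.5, -2.0, 3.25};
  std::vector<double> y = MiniBlockGrad<double>::Forward(x);
  assert(y == x);                       // forward is the identity

  std::vector<double> g = {0.1, 0.2, 0.3};
  MiniBlockGrad<double>::Backward(&g);
  for (double v : g) assert(v == 0.0);  // nothing flows upstream

  // The same template instantiates for float (and, in mxnet, half_t).
  std::vector<float> xf = {1.0f, 2.0f};
  assert(MiniBlockGrad<float>::Forward(xf) == xf);
  return 0;
}

Both kernels are type-agnostic; the template parameter only changes the storage width, which is exactly why a single implementation can serve all three floating types.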
src/operator/block_grad.cc (+13, -4)
@@ -9,12 +9,21 @@
 namespace mxnet {
 namespace op {
 template<>
-Operator *CreateOp<cpu>() {
-  return new BlockGradientOp<cpu>();
+Operator *CreateOp<cpu>(int dtype) {
+  Operator *op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new BlockGradientOp<cpu, DType>();
+  });
+  return op;
 }
 
-Operator *BlockGradientProp::CreateOperator(Context ctx) const {
-  DO_BIND_DISPATCH(CreateOp);
+Operator *BlockGradientProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                                              std::vector<int> *in_type) const {
+  std::vector<TShape> out_shape, aux_shape;
+  std::vector<int> out_type, aux_type;
+  CHECK(InferType(in_type, &out_type, &aux_type));
+  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
+  DO_BIND_DISPATCH(CreateOp, in_type->at(0));
 }
 
 MXNET_REGISTER_OP_PROPERTY(BlockGrad, BlockGradientProp)

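CreateOp<cpu> now receives the runtime dtype code and lets MSHADOW_REAL_TYPE_SWITCH expand its body once per supported floating type, binding DType to the matching C++ type; CreateOperatorEx runs InferType/InferShape first so the dispatch always sees a concrete dtype. A hedged, self-contained sketch of that int-code-to-template dispatch follows; the TypeFlag enum and REAL_TYPE_SWITCH macro are local stand-ins written for this sketch, not the real mshadow definitions.

// dtype_dispatch.cc -- illustrative sketch of the dispatch pattern behind
// MSHADOW_REAL_TYPE_SWITCH; enum and macro are stand-ins, not mshadow's.
#include <cstdio>
#include <stdexcept>

enum TypeFlag { kFloat32 = 0, kFloat64 = 1 };  // mshadow also has kFloat16

// Expand the body once per supported type, with DType bound appropriately.
#define REAL_TYPE_SWITCH(type, DType, ...)                          \
  switch (type) {                                                   \
    case kFloat32: { typedef float  DType; __VA_ARGS__ } break;     \
    case kFloat64: { typedef double DType; __VA_ARGS__ } break;     \
    default: throw std::runtime_error("unknown type flag");         \
  }

struct Operator { virtual ~Operator() {} };

template <typename DType>
struct MiniBlockGradOp : Operator {
  MiniBlockGradOp() {
    std::printf("instantiated for a %zu-byte type\n", sizeof(DType));
  }
};

// Analogue of CreateOp<cpu>(int dtype): runtime int -> compile-time type.
Operator* CreateOp(int dtype) {
  Operator* op = nullptr;
  REAL_TYPE_SWITCH(dtype, DType, { op = new MiniBlockGradOp<DType>(); });
  return op;
}

int main() {
  delete CreateOp(kFloat32);  // picks MiniBlockGradOp<float>
  delete CreateOp(kFloat64);  // picks MiniBlockGradOp<double>
}

The real macro additionally maps kFloat16 to mshadow's half_t, which is what lets the GPU test below exercise np.float16.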
src/operator/block_grad.cu (+6, -2)
@@ -9,8 +9,12 @@
 namespace mxnet {
 namespace op {
 template<>
-Operator *CreateOp<gpu>() {
-  return new BlockGradientOp<gpu>();
+Operator *CreateOp<gpu>(int dtype) {
+  Operator *op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new BlockGradientOp<gpu, DType>();
+  });
+  return op;
 }
 
 }  // namespace op

tests/python/gpu/test_operator_gpu.py (+10, -0)
@@ -114,6 +114,15 @@ def test_reshape_with_type():
                 {'ctx': mx.cpu(0), 'reshape_data': (2, 2, 2, 10), 'type_dict': {'reshape_data': np.float32}}]
     check_consistency(sym, ctx_list)
 
+def test_blockgrad_with_type():
+    sym = mx.sym.BlockGrad(name='bg')
+    ctx_list = [{'ctx': mx.gpu(0), 'bg_data': (2, 2, 2, 10), 'type_dict': {'bg_data': np.float64}},
+                {'ctx': mx.gpu(0), 'bg_data': (2, 2, 2, 10), 'type_dict': {'bg_data': np.float32}},
+                {'ctx': mx.gpu(0), 'bg_data': (2, 2, 2, 10), 'type_dict': {'bg_data': np.float16}},
+                {'ctx': mx.cpu(0), 'bg_data': (2, 2, 2, 10), 'type_dict': {'bg_data': np.float64}},
+                {'ctx': mx.cpu(0), 'bg_data': (2, 2, 2, 10), 'type_dict': {'bg_data': np.float32}}]
+    check_consistency(sym, ctx_list)
+
 def test_fullyconnected_with_type():
     sym = mx.sym.FullyConnected(num_hidden=3, name='inner')
     ctx_list = [{'ctx': mx.gpu(0), 'inner_data': (2, 10), 'type_dict': {'inner_data': np.float64}},
@@ -139,6 +148,7 @@ def test_activation_with_type():
     test_upsampling_with_type()
     test_concat_with_type()
     test_reshape_with_type()
+    test_blockgrad_with_type()
     test_fullyconnected_with_type()
     test_activation_with_type()
     #test_softmax_with_shape((3,4), mx.gpu())

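The new test leans on check_consistency, which runs the same symbol under every context/dtype combination in ctx_list and compares outputs (and gradients) against the highest-precision run within a dtype-dependent tolerance. A stand-alone C++ sketch of that comparison idea, mirroring BlockGrad's identity forward; the tolerances here are illustrative, not mxnet's actual defaults.

// consistency_check.cc -- illustrative analogue of check_consistency:
// compute a float64 reference, then verify lower-precision runs agree
// within a tolerance scaled to the dtype's precision.
#include <cassert>
#include <cmath>
#include <vector>

template <typename DType>
std::vector<DType> ForwardIdentity(const std::vector<DType>& data) {
  return data;  // BlockGrad's forward pass
}

template <typename DType>
bool AllClose(const std::vector<DType>& a, const std::vector<double>& ref,
              double tol) {
  for (size_t i = 0; i < ref.size(); ++i)
    if (std::fabs(static_cast<double>(a[i]) - ref[i]) > tol) return false;
  return true;
}

int main() {
  std::vector<double> ref_in = {0.5, -1.25, 2.0};
  std::vector<double> ref_out = ForwardIdentity(ref_in);  // float64 reference

  // Re-run in float32 and compare against the reference.
  std::vector<float> in32(ref_in.begin(), ref_in.end());
  std::vector<float> out32 = ForwardIdentity(in32);
  assert(AllClose(out32, ref_out, 1e-6));  // float32 tolerance (illustrative)

  // A looser, float16-like tolerance would also hold for an identity op.
  assert(AllClose(out32, ref_out, 1e-3));
  return 0;
}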