Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit a37dcd4

Browse files
authored
Fix SliceChannel Type inference (#16748)
* Refactor elemwise_op_common and change SliceChannel InferType * Add gluoncv models * Comment Faster RCNN models
1 parent 5dfa121 commit a37dcd4

File tree

4 files changed

+167
-23
lines changed

4 files changed

+167
-23
lines changed

example/automatic-mixed-precision/amp_model_conversion.py

Lines changed: 132 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,129 @@ def save_params(fname, arg_params, aux_params, logger=None):
5858
'imagenet1k-resnext-101-64x4d',
5959
'imagenet11k-place365ch-resnet-152',
6060
'imagenet11k-place365ch-resnet-50']
61-
gluon_models = ['resnet18_v1',
61+
# Faster RCNN and Mask RCNN commented because of model loading issues
62+
# https://github.com/dmlc/gluon-cv/issues/1034
63+
gluon_models = [#'faster_rcnn_fpn_resnet50_v1b_coco',
64+
'mobilenetv2_0.75',
65+
'cifar_resnet56_v1',
66+
'mobilenet0.25',
67+
'mobilenet1.0',
68+
#'mask_rcnn_fpn_resnet50_v1b_coco',
69+
'simple_pose_resnet152_v1b',
70+
'ssd_512_resnet50_v1_voc',
71+
#'faster_rcnn_resnet50_v1b_voc',
72+
'cifar_resnet20_v1',
73+
'yolo3_darknet53_voc',
74+
'resnet101_v1c',
75+
'simple_pose_resnet18_v1b',
76+
#'mask_rcnn_resnet50_v1b_coco',
77+
'ssd_512_mobilenet1.0_coco',
78+
'vgg19_bn',
79+
#'faster_rcnn_resnet50_v1b_coco',
80+
'cifar_resnet110_v1',
81+
'yolo3_mobilenet1.0_voc',
82+
'cifar_resnext29_16x64d',
83+
'resnet34_v1',
84+
'densenet121',
85+
#'mask_rcnn_fpn_resnet101_v1d_coco',
86+
'vgg13_bn',
87+
'vgg19',
88+
'resnet152_v1d',
89+
'resnet152_v1s',
90+
'densenet201',
91+
'alexnet',
92+
'se_resnext50_32x4d',
93+
'resnet50_v1d_0.86',
94+
'resnet18_v1b_0.89',
95+
'yolo3_darknet53_coco',
96+
'resnet152_v1',
97+
'resnext101_64x4d',
98+
'vgg13',
99+
'resnet101_v1d_0.76',
100+
'simple_pose_resnet50_v1d',
101+
'senet_154',
62102
'resnet50_v1',
63-
'resnet101_v1',
103+
'se_resnext101_32x4d',
104+
'fcn_resnet101_voc',
105+
'resnet152_v2',
106+
#'mask_rcnn_resnet101_v1d_coco',
107+
'squeezenet1.1',
108+
'mobilenet0.5',
109+
'resnet34_v2',
110+
'resnet18_v1',
111+
'resnet152_v1b',
112+
'resnet101_v2',
113+
'cifar_resnet56_v2',
114+
'ssd_512_resnet101_v2_voc',
115+
'resnet50_v1d_0.37',
116+
'mobilenetv2_0.5',
117+
#'faster_rcnn_fpn_bn_resnet50_v1b_coco',
118+
'resnet50_v1c',
119+
'densenet161',
120+
'simple_pose_resnet50_v1b',
121+
'resnet18_v1b',
122+
'darknet53',
123+
'fcn_resnet50_ade',
124+
'cifar_wideresnet28_10',
125+
'simple_pose_resnet101_v1d',
126+
'vgg16',
127+
'ssd_512_resnet50_v1_coco',
128+
'resnet101_v1d_0.73',
64129
'squeezenet1.0',
65-
'mobilenet1.0',
130+
'resnet50_v1b',
131+
#'faster_rcnn_resnet101_v1d_coco',
132+
'ssd_512_mobilenet1.0_voc',
133+
'cifar_wideresnet40_8',
134+
'cifar_wideresnet16_10',
135+
'cifar_resnet110_v2',
136+
'resnet101_v1s',
137+
'mobilenetv2_0.25',
138+
'resnet152_v1c',
139+
'se_resnext101_64x4d',
140+
#'faster_rcnn_fpn_resnet101_v1d_coco',
141+
'resnet50_v1d',
142+
'densenet169',
143+
'resnet34_v1b',
144+
'resnext50_32x4d',
145+
'resnet101_v1',
146+
'resnet101_v1b',
147+
'resnet50_v1s',
148+
'mobilenet0.75',
149+
'cifar_resnet20_v2',
150+
'resnet101_v1d',
151+
'vgg11_bn',
152+
'resnet18_v2',
153+
'vgg11',
154+
'simple_pose_resnet101_v1b',
155+
'resnext101_32x4d',
156+
'resnet50_v2',
157+
'vgg16_bn',
66158
'mobilenetv2_1.0',
67-
'inceptionv3']
159+
'resnet50_v1d_0.48',
160+
'resnet50_v1d_0.11',
161+
'fcn_resnet101_ade',
162+
'simple_pose_resnet152_v1d',
163+
'yolo3_mobilenet1.0_coco',
164+
'fcn_resnet101_coco']
165+
# TODO(anisub): add support for other models from gluoncv
166+
# Not supported today mostly because of broken net.forward calls
167+
segmentation_models = ['deeplab_resnet50_ade',
168+
'psp_resnet101_voc',
169+
'deeplab_resnet152_voc',
170+
'deeplab_resnet101_ade',
171+
'deeplab_resnet152_coco',
172+
'psp_resnet101_ade',
173+
'deeplab_resnet101_coco',
174+
'psp_resnet101_citys',
175+
'psp_resnet50_ade',
176+
'psp_resnet101_coco',
177+
'deeplab_resnet101_voc']
178+
calib_ssd_models = ["ssd_512_vgg16_atrous_voc",
179+
"ssd_300_vgg16_atrous_voc",
180+
"ssd_300_vgg16_atrous_coco"]
181+
calib_inception_models = ["inceptionv3"]
182+
gluon_models = gluon_models + segmentation_models + \
183+
calib_ssd_models + calib_inception_models
68184
models = symbolic_models + gluon_models
69185

70186
parser = argparse.ArgumentParser(description='Convert a provided FP32 model to a mixed precision model')
@@ -106,14 +222,23 @@ def save_params(fname, arg_params, aux_params, logger=None):
106222
else:
107223
assert args.model in gluon_models, "Please choose one of the available gluon models: {} \
108224
If you want to use symbolic model instead, remove --use-gluon-model when running the script".format(gluon_models)
225+
shape = None
226+
if args.model in segmentation_models:
227+
shape = (1, 3, 480, 480)
228+
elif args.model in calib_ssd_models:
229+
shape = (1, 3, 512, 544)
230+
elif args.model in calib_inception_models:
231+
shape = (1, 3, 299, 299)
232+
else:
233+
shape = (1, 3, 224, 224)
109234
net = gluoncv.model_zoo.get_model(args.model, pretrained=True)
110235
net.hybridize()
111-
result_before1 = net.forward(mx.nd.zeros((1, 3, 224, 224)))
236+
result_before1 = net.forward(mx.nd.random.uniform(shape=shape))
112237
net.export("{}".format(args.model))
113238
net = amp.convert_hybrid_block(net, cast_optional_params=args.cast_optional_params)
114239
net.export("{}-amp".format(args.model), remove_amp_cast=False)
115240
if args.run_dummy_inference:
116241
logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1")
117-
result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
118-
result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
242+
result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0)))
243+
result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0)))
119244
logger.info("Inference run successfully")

src/operator/elemwise_op_common.h

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,10 @@ template<typename AttrType, bool (*is_none)(const AttrType&),
116116
bool (*assign)(AttrType*, const AttrType&), bool reverse_infer,
117117
std::string (*attr_string)(const AttrType&),
118118
index_t n_in = -1, index_t n_out = -1>
119-
inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
120-
std::vector<AttrType> *in_attrs,
121-
std::vector<AttrType> *out_attrs,
122-
const AttrType& none) {
119+
inline bool ElemwiseAttrHelper(const std::string& node_name,
120+
std::vector<AttrType> *in_attrs,
121+
std::vector<AttrType> *out_attrs,
122+
const AttrType& none) {
123123
AttrType dattr = none;
124124
size_t in_size = in_attrs->size();
125125
size_t out_size = out_attrs->size();
@@ -133,7 +133,7 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
133133
auto deduce = [&](const std::vector<AttrType>& vec, size_t size, const char *name) {
134134
for (size_t i = 0; i < size; ++i) {
135135
CHECK(assign(&dattr, vec.at(i)))
136-
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
136+
<< "Incompatible attr in node " << node_name << " at " << i << "-th "
137137
<< name << ": " << "expected " << attr_string(dattr)
138138
<< ", got " << attr_string(vec.at(i));
139139
}
@@ -145,7 +145,7 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
145145
auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
146146
for (size_t i = 0; i < size; ++i) {
147147
CHECK(assign(&(vec->at(i)), dattr))
148-
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
148+
<< "Incompatible attr in node " << node_name << " at " << i << "-th "
149149
<< name << ": " << "expected " << attr_string(dattr)
150150
<< ", got " << attr_string(vec->at(i));
151151
}
@@ -158,6 +158,21 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
158158
return true;
159159
}
160160

161+
162+
template<typename AttrType, bool (*is_none)(const AttrType&),
163+
bool (*assign)(AttrType*, const AttrType&), bool reverse_infer,
164+
std::string (*attr_string)(const AttrType&),
165+
index_t n_in = -1, index_t n_out = -1>
166+
inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
167+
std::vector<AttrType> *in_attrs,
168+
std::vector<AttrType> *out_attrs,
169+
const AttrType& none) {
170+
return ElemwiseAttrHelper<AttrType, is_none,
171+
assign, reverse_infer,
172+
attr_string, n_in,
173+
n_out>(attrs.name, in_attrs, out_attrs, none);
174+
}
175+
161176
template<index_t n_in, index_t n_out>
162177
inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs,
163178
mxnet::ShapeVector *in_attrs,

src/operator/slice_channel-inl.h

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include <utility>
3737
#include "./operator_common.h"
3838
#include "./channel_op_common.h"
39+
#include "./elemwise_op_common.h"
3940

4041
namespace mxnet {
4142
namespace op {
@@ -176,16 +177,10 @@ class SliceChannelProp : public OperatorProperty {
176177
bool InferType(std::vector<int> *in_type,
177178
std::vector<int> *out_type,
178179
std::vector<int> *aux_type) const override {
179-
CHECK_EQ(in_type->size(), 1U);
180-
int dtype = (*in_type)[0];
181-
CHECK_NE(dtype, -1) << "First input must have specified type";
182-
out_type->clear();
183-
out_type->reserve(param_.num_outputs);
184-
for (int i = 0; i < param_.num_outputs; ++i) {
185-
out_type->push_back(dtype);
186-
}
187-
aux_type->clear();
188-
return true;
180+
std::string node_name = "slice_channel_node";
181+
return ElemwiseAttrHelper<int, type_is_none,
182+
type_assign, true,
183+
type_string, 1>(node_name, in_type, out_type, -1);
189184
}
190185

191186
bool InferShape(mxnet::ShapeVector *in_shape,

tests/python/gpu/test_contrib_amp.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,15 @@ def test_fp16_casting():
475475
exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
476476
assert exe.arg_arrays[0].dtype == np.float16
477477

478+
# Check for symbol which has slice channel
479+
data = mx.sym.var("data")
480+
data2 = mx.sym.var("data2")
481+
data._set_attr(__dtype__="-1")
482+
data2._set_attr(__dtype__="-1")
483+
concat_res = mx.sym.concat(data, data2)
484+
out = mx.sym.split(concat_res, axis=1, num_outputs=2)
485+
final_res = amp.convert_symbol(out)
486+
478487

479488
if __name__ == '__main__':
480489
import nose

0 commit comments

Comments
 (0)