Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Fix SliceChannel Type inference #16748

Merged
merged 3 commits into from
Nov 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 132 additions & 7 deletions example/automatic-mixed-precision/amp_model_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,129 @@ def save_params(fname, arg_params, aux_params, logger=None):
'imagenet1k-resnext-101-64x4d',
'imagenet11k-place365ch-resnet-152',
'imagenet11k-place365ch-resnet-50']
gluon_models = ['resnet18_v1',
# Faster RCNN and Mask RCNN commented because of model loading issues
# https://github.com/dmlc/gluon-cv/issues/1034
gluon_models = [#'faster_rcnn_fpn_resnet50_v1b_coco',
'mobilenetv2_0.75',
'cifar_resnet56_v1',
'mobilenet0.25',
'mobilenet1.0',
#'mask_rcnn_fpn_resnet50_v1b_coco',
'simple_pose_resnet152_v1b',
'ssd_512_resnet50_v1_voc',
#'faster_rcnn_resnet50_v1b_voc',
'cifar_resnet20_v1',
'yolo3_darknet53_voc',
'resnet101_v1c',
'simple_pose_resnet18_v1b',
#'mask_rcnn_resnet50_v1b_coco',
'ssd_512_mobilenet1.0_coco',
'vgg19_bn',
#'faster_rcnn_resnet50_v1b_coco',
'cifar_resnet110_v1',
'yolo3_mobilenet1.0_voc',
'cifar_resnext29_16x64d',
'resnet34_v1',
'densenet121',
#'mask_rcnn_fpn_resnet101_v1d_coco',
'vgg13_bn',
'vgg19',
'resnet152_v1d',
'resnet152_v1s',
'densenet201',
'alexnet',
'se_resnext50_32x4d',
'resnet50_v1d_0.86',
'resnet18_v1b_0.89',
'yolo3_darknet53_coco',
'resnet152_v1',
'resnext101_64x4d',
'vgg13',
'resnet101_v1d_0.76',
'simple_pose_resnet50_v1d',
'senet_154',
'resnet50_v1',
'resnet101_v1',
'se_resnext101_32x4d',
'fcn_resnet101_voc',
'resnet152_v2',
#'mask_rcnn_resnet101_v1d_coco',
'squeezenet1.1',
'mobilenet0.5',
'resnet34_v2',
'resnet18_v1',
'resnet152_v1b',
'resnet101_v2',
'cifar_resnet56_v2',
'ssd_512_resnet101_v2_voc',
'resnet50_v1d_0.37',
'mobilenetv2_0.5',
#'faster_rcnn_fpn_bn_resnet50_v1b_coco',
'resnet50_v1c',
'densenet161',
'simple_pose_resnet50_v1b',
'resnet18_v1b',
'darknet53',
'fcn_resnet50_ade',
'cifar_wideresnet28_10',
'simple_pose_resnet101_v1d',
'vgg16',
'ssd_512_resnet50_v1_coco',
'resnet101_v1d_0.73',
'squeezenet1.0',
'mobilenet1.0',
'resnet50_v1b',
#'faster_rcnn_resnet101_v1d_coco',
'ssd_512_mobilenet1.0_voc',
'cifar_wideresnet40_8',
'cifar_wideresnet16_10',
'cifar_resnet110_v2',
'resnet101_v1s',
'mobilenetv2_0.25',
'resnet152_v1c',
'se_resnext101_64x4d',
#'faster_rcnn_fpn_resnet101_v1d_coco',
'resnet50_v1d',
'densenet169',
'resnet34_v1b',
'resnext50_32x4d',
'resnet101_v1',
'resnet101_v1b',
'resnet50_v1s',
'mobilenet0.75',
'cifar_resnet20_v2',
'resnet101_v1d',
'vgg11_bn',
'resnet18_v2',
'vgg11',
'simple_pose_resnet101_v1b',
'resnext101_32x4d',
'resnet50_v2',
'vgg16_bn',
'mobilenetv2_1.0',
'inceptionv3']
'resnet50_v1d_0.48',
'resnet50_v1d_0.11',
'fcn_resnet101_ade',
'simple_pose_resnet152_v1d',
'yolo3_mobilenet1.0_coco',
'fcn_resnet101_coco']
# TODO(anisub): add support for other models from gluoncv
# Not supported today mostly because of broken net.forward calls
segmentation_models = ['deeplab_resnet50_ade',
'psp_resnet101_voc',
'deeplab_resnet152_voc',
'deeplab_resnet101_ade',
'deeplab_resnet152_coco',
'psp_resnet101_ade',
'deeplab_resnet101_coco',
'psp_resnet101_citys',
'psp_resnet50_ade',
'psp_resnet101_coco',
'deeplab_resnet101_voc']
calib_ssd_models = ["ssd_512_vgg16_atrous_voc",
"ssd_300_vgg16_atrous_voc",
"ssd_300_vgg16_atrous_coco"]
calib_inception_models = ["inceptionv3"]
gluon_models = gluon_models + segmentation_models + \
calib_ssd_models + calib_inception_models
models = symbolic_models + gluon_models

parser = argparse.ArgumentParser(description='Convert a provided FP32 model to a mixed precision model')
Expand Down Expand Up @@ -106,14 +222,23 @@ def save_params(fname, arg_params, aux_params, logger=None):
else:
assert args.model in gluon_models, "Please choose one of the available gluon models: {} \
If you want to use symbolic model instead, remove --use-gluon-model when running the script".format(gluon_models)
shape = None
if args.model in segmentation_models:
shape = (1, 3, 480, 480)
elif args.model in calib_ssd_models:
shape = (1, 3, 512, 544)
elif args.model in calib_inception_models:
shape = (1, 3, 299, 299)
else:
shape = (1, 3, 224, 224)
net = gluoncv.model_zoo.get_model(args.model, pretrained=True)
net.hybridize()
result_before1 = net.forward(mx.nd.zeros((1, 3, 224, 224)))
result_before1 = net.forward(mx.nd.random.uniform(shape=shape))
net.export("{}".format(args.model))
net = amp.convert_hybrid_block(net, cast_optional_params=args.cast_optional_params)
net.export("{}-amp".format(args.model), remove_amp_cast=False)
if args.run_dummy_inference:
logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1")
result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0)))
result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0)))
logger.info("Inference run successfully")
27 changes: 21 additions & 6 deletions src/operator/elemwise_op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ template<typename AttrType, bool (*is_none)(const AttrType&),
bool (*assign)(AttrType*, const AttrType&), bool reverse_infer,
std::string (*attr_string)(const AttrType&),
index_t n_in = -1, index_t n_out = -1>
inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs,
const AttrType& none) {
inline bool ElemwiseAttrHelper(const std::string& node_name,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs,
const AttrType& none) {
AttrType dattr = none;
size_t in_size = in_attrs->size();
size_t out_size = out_attrs->size();
Expand All @@ -133,7 +133,7 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
auto deduce = [&](const std::vector<AttrType>& vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&dattr, vec.at(i)))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< "Incompatible attr in node " << node_name << " at " << i << "-th "
<< name << ": " << "expected " << attr_string(dattr)
<< ", got " << attr_string(vec.at(i));
}
Expand All @@ -145,7 +145,7 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&(vec->at(i)), dattr))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< "Incompatible attr in node " << node_name << " at " << i << "-th "
<< name << ": " << "expected " << attr_string(dattr)
<< ", got " << attr_string(vec->at(i));
}
Expand All @@ -158,6 +158,21 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
return true;
}


/*!
 * \brief NodeAttrs-based entry point for element-wise attribute inference.
 *
 * Thin wrapper that forwards every template argument unchanged to
 * ElemwiseAttrHelper and supplies attrs.name as the node name; the helper
 * uses that name only when building its "Incompatible attr in node ..."
 * diagnostics. Keeping this overload lets existing callers that hold a
 * nnvm::NodeAttrs continue to work, while callers without one can invoke
 * ElemwiseAttrHelper directly with a name of their own.
 *
 * \param attrs node attributes; only attrs.name is read here.
 * \param in_attrs input attribute vector, passed through to the helper.
 * \param out_attrs output attribute vector, passed through to the helper.
 * \param none the value passed to the helper as its `none` argument.
 * \return whatever ElemwiseAttrHelper returns.
 */
template<typename AttrType, bool (*is_none)(const AttrType&),
bool (*assign)(AttrType*, const AttrType&), bool reverse_infer,
std::string (*attr_string)(const AttrType&),
index_t n_in = -1, index_t n_out = -1>
inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs,
const AttrType& none) {
// Delegate to the name-based helper so both entry points share one implementation.
return ElemwiseAttrHelper<AttrType, is_none,
assign, reverse_infer,
attr_string, n_in,
n_out>(attrs.name, in_attrs, out_attrs, none);
}

template<index_t n_in, index_t n_out>
inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
Expand Down
15 changes: 5 additions & 10 deletions src/operator/slice_channel-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <utility>
#include "./operator_common.h"
#include "./channel_op_common.h"
#include "./elemwise_op_common.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -176,16 +177,10 @@ class SliceChannelProp : public OperatorProperty {
bool InferType(std::vector<int> *in_type,
std::vector<int> *out_type,
std::vector<int> *aux_type) const override {
CHECK_EQ(in_type->size(), 1U);
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "First input must have specified type";
out_type->clear();
out_type->reserve(param_.num_outputs);
for (int i = 0; i < param_.num_outputs; ++i) {
out_type->push_back(dtype);
}
aux_type->clear();
return true;
std::string node_name = "slice_channel_node";
return ElemwiseAttrHelper<int, type_is_none,
type_assign, true,
type_string, 1>(node_name, in_type, out_type, -1);
}

bool InferShape(mxnet::ShapeVector *in_shape,
Expand Down
9 changes: 9 additions & 0 deletions tests/python/gpu/test_contrib_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,15 @@ def test_fp16_casting():
exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
assert exe.arg_arrays[0].dtype == np.float16

# Check for symbol which has slice channel
data = mx.sym.var("data")
data2 = mx.sym.var("data2")
data._set_attr(__dtype__="-1")
data2._set_attr(__dtype__="-1")
concat_res = mx.sym.concat(data, data2)
out = mx.sym.split(concat_res, axis=1, num_outputs=2)
final_res = amp.convert_symbol(out)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any checks or assertions needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, before this change this test was failing inside the convert_symbol call itself. We only need to check that convert_symbol completes successfully.



if __name__ == '__main__':
import nose
Expand Down