diff --git a/example/automatic-mixed-precision/amp_model_conversion.py b/example/automatic-mixed-precision/amp_model_conversion.py
index fcc2ad69dd62..b363e0244a10 100644
--- a/example/automatic-mixed-precision/amp_model_conversion.py
+++ b/example/automatic-mixed-precision/amp_model_conversion.py
@@ -58,13 +58,129 @@ def save_params(fname, arg_params, aux_params, logger=None):
                        'imagenet1k-resnext-101-64x4d',
                        'imagenet11k-place365ch-resnet-152',
                        'imagenet11k-place365ch-resnet-50']
-    gluon_models = ['resnet18_v1',
+    # Faster RCNN and Mask RCNN commented because of model loading issues
+    # https://github.com/dmlc/gluon-cv/issues/1034
+    gluon_models = [#'faster_rcnn_fpn_resnet50_v1b_coco',
+                    'mobilenetv2_0.75',
+                    'cifar_resnet56_v1',
+                    'mobilenet0.25',
+                    'mobilenet1.0',
+                    #'mask_rcnn_fpn_resnet50_v1b_coco',
+                    'simple_pose_resnet152_v1b',
+                    'ssd_512_resnet50_v1_voc',
+                    #'faster_rcnn_resnet50_v1b_voc',
+                    'cifar_resnet20_v1',
+                    'yolo3_darknet53_voc',
+                    'resnet101_v1c',
+                    'simple_pose_resnet18_v1b',
+                    #'mask_rcnn_resnet50_v1b_coco',
+                    'ssd_512_mobilenet1.0_coco',
+                    'vgg19_bn',
+                    #'faster_rcnn_resnet50_v1b_coco',
+                    'cifar_resnet110_v1',
+                    'yolo3_mobilenet1.0_voc',
+                    'cifar_resnext29_16x64d',
+                    'resnet34_v1',
+                    'densenet121',
+                    #'mask_rcnn_fpn_resnet101_v1d_coco',
+                    'vgg13_bn',
+                    'vgg19',
+                    'resnet152_v1d',
+                    'resnet152_v1s',
+                    'densenet201',
+                    'alexnet',
+                    'se_resnext50_32x4d',
+                    'resnet50_v1d_0.86',
+                    'resnet18_v1b_0.89',
+                    'yolo3_darknet53_coco',
+                    'resnet152_v1',
+                    'resnext101_64x4d',
+                    'vgg13',
+                    'resnet101_v1d_0.76',
+                    'simple_pose_resnet50_v1d',
+                    'senet_154',
                     'resnet50_v1',
-                    'resnet101_v1',
+                    'se_resnext101_32x4d',
+                    'fcn_resnet101_voc',
+                    'resnet152_v2',
+                    #'mask_rcnn_resnet101_v1d_coco',
+                    'squeezenet1.1',
+                    'mobilenet0.5',
+                    'resnet34_v2',
+                    'resnet18_v1',
+                    'resnet152_v1b',
+                    'resnet101_v2',
+                    'cifar_resnet56_v2',
+                    'ssd_512_resnet101_v2_voc',
+                    'resnet50_v1d_0.37',
+                    'mobilenetv2_0.5',
+                    #'faster_rcnn_fpn_bn_resnet50_v1b_coco',
+                    'resnet50_v1c',
+                    'densenet161',
+                    'simple_pose_resnet50_v1b',
+                    'resnet18_v1b',
+                    'darknet53',
+                    'fcn_resnet50_ade',
+                    'cifar_wideresnet28_10',
+                    'simple_pose_resnet101_v1d',
+                    'vgg16',
+                    'ssd_512_resnet50_v1_coco',
+                    'resnet101_v1d_0.73',
                     'squeezenet1.0',
-                    'mobilenet1.0',
+                    'resnet50_v1b',
+                    #'faster_rcnn_resnet101_v1d_coco',
+                    'ssd_512_mobilenet1.0_voc',
+                    'cifar_wideresnet40_8',
+                    'cifar_wideresnet16_10',
+                    'cifar_resnet110_v2',
+                    'resnet101_v1s',
+                    'mobilenetv2_0.25',
+                    'resnet152_v1c',
+                    'se_resnext101_64x4d',
+                    #'faster_rcnn_fpn_resnet101_v1d_coco',
+                    'resnet50_v1d',
+                    'densenet169',
+                    'resnet34_v1b',
+                    'resnext50_32x4d',
+                    'resnet101_v1',
+                    'resnet101_v1b',
+                    'resnet50_v1s',
+                    'mobilenet0.75',
+                    'cifar_resnet20_v2',
+                    'resnet101_v1d',
+                    'vgg11_bn',
+                    'resnet18_v2',
+                    'vgg11',
+                    'simple_pose_resnet101_v1b',
+                    'resnext101_32x4d',
+                    'resnet50_v2',
+                    'vgg16_bn',
                     'mobilenetv2_1.0',
-                    'inceptionv3']
+                    'resnet50_v1d_0.48',
+                    'resnet50_v1d_0.11',
+                    'fcn_resnet101_ade',
+                    'simple_pose_resnet152_v1d',
+                    'yolo3_mobilenet1.0_coco',
+                    'fcn_resnet101_coco']
+    # TODO(anisub): add support for other models from gluoncv
+    # Not supported today mostly because of broken net.forward calls
+    segmentation_models = ['deeplab_resnet50_ade',
+                           'psp_resnet101_voc',
+                           'deeplab_resnet152_voc',
+                           'deeplab_resnet101_ade',
+                           'deeplab_resnet152_coco',
+                           'psp_resnet101_ade',
+                           'deeplab_resnet101_coco',
+                           'psp_resnet101_citys',
+                           'psp_resnet50_ade',
+                           'psp_resnet101_coco',
+                           'deeplab_resnet101_voc']
+    calib_ssd_models = ["ssd_512_vgg16_atrous_voc",
+                        "ssd_300_vgg16_atrous_voc",
+                        "ssd_300_vgg16_atrous_coco"]
+    calib_inception_models = ["inceptionv3"]
+    gluon_models = gluon_models + segmentation_models + \
+                   calib_ssd_models + calib_inception_models
     models = symbolic_models + gluon_models
 
     parser = argparse.ArgumentParser(description='Convert a provided FP32 model to a mixed precision model')
@@ -106,14 +222,23 @@ def save_params(fname, arg_params, aux_params, logger=None):
     else:
         assert args.model in gluon_models, "Please choose one of the available gluon models: {} \
                              If you want to use symbolic model instead, remove --use-gluon-model when running the script".format(gluon_models)
+        shape = None
+        if args.model in segmentation_models:
+            shape = (1, 3, 480, 480)
+        elif args.model in calib_ssd_models:
+            shape = (1, 3, 512, 544)
+        elif args.model in calib_inception_models:
+            shape = (1, 3, 299, 299)
+        else:
+            shape = (1, 3, 224, 224)
         net = gluoncv.model_zoo.get_model(args.model, pretrained=True)
         net.hybridize()
-        result_before1 = net.forward(mx.nd.zeros((1, 3, 224, 224)))
+        result_before1 = net.forward(mx.nd.random.uniform(shape=shape))
         net.export("{}".format(args.model))
         net = amp.convert_hybrid_block(net, cast_optional_params=args.cast_optional_params)
         net.export("{}-amp".format(args.model), remove_amp_cast=False)
         if args.run_dummy_inference:
             logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1")
-            result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
-            result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
+            result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0)))
+            result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0)))
             logger.info("Inference run successfully")
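
The conversion flow above picks an input shape per model family (1x3x480x480 for the segmentation networks, 1x3x512x544 for the calibrated SSD models, 1x3x299x299 for inceptionv3, 1x3x224x224 otherwise), traces the network once with random data, exports the FP32 model, and then exports an AMP-converted copy. A minimal standalone sketch of that flow for a single gluoncv model, assuming an MXNet build with CUDA and gluoncv installed; the model name and shape below are illustrative, not prescribed by the patch:

    import mxnet as mx
    import gluoncv
    from mxnet.contrib import amp

    model_name = 'resnet50_v1'            # any entry from gluon_models
    shape = (1, 3, 224, 224)              # chosen per model family, as in the script

    net = gluoncv.model_zoo.get_model(model_name, pretrained=True)
    net.hybridize()
    net.forward(mx.nd.random.uniform(shape=shape))   # trace the graph once
    net.export(model_name)                           # FP32 symbol + params

    net = amp.convert_hybrid_block(net, cast_optional_params=False)
    net.export("{}-amp".format(model_name), remove_amp_cast=False)
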
"ssd_300_vgg16_atrous_coco"] + calib_inception_models = ["inceptionv3"] + gluon_models = gluon_models + segmentation_models + \ + calib_ssd_models + calib_inception_models models = symbolic_models + gluon_models parser = argparse.ArgumentParser(description='Convert a provided FP32 model to a mixed precision model') @@ -106,14 +222,23 @@ def save_params(fname, arg_params, aux_params, logger=None): else: assert args.model in gluon_models, "Please choose one of the available gluon models: {} \ If you want to use symbolic model instead, remove --use-gluon-model when running the script".format(gluon_models) + shape = None + if args.model in segmentation_models: + shape = (1, 3, 480, 480) + elif args.model in calib_ssd_models: + shape = (1, 3, 512, 544) + elif args.model in calib_inception_models: + shape = (1, 3, 299, 299) + else: + shape = (1, 3, 224, 224) net = gluoncv.model_zoo.get_model(args.model, pretrained=True) net.hybridize() - result_before1 = net.forward(mx.nd.zeros((1, 3, 224, 224))) + result_before1 = net.forward(mx.nd.random.uniform(shape=shape)) net.export("{}".format(args.model)) net = amp.convert_hybrid_block(net, cast_optional_params=args.cast_optional_params) net.export("{}-amp".format(args.model), remove_amp_cast=False) if args.run_dummy_inference: logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1") - result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0))) - result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0))) + result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0))) + result_after = net.forward(mx.nd.random.uniform(shape=shape, dtype=np.float32, ctx=mx.gpu(0))) logger.info("Inference run successfully") diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index dc83a4b1f87f..6711297718b2 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -116,10 +116,10 @@ template -inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs, - const AttrType& none) { +inline bool ElemwiseAttrHelper(const std::string& node_name, + std::vector *in_attrs, + std::vector *out_attrs, + const AttrType& none) { AttrType dattr = none; size_t in_size = in_attrs->size(); size_t out_size = out_attrs->size(); @@ -133,7 +133,7 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs, auto deduce = [&](const std::vector& vec, size_t size, const char *name) { for (size_t i = 0; i < size; ++i) { CHECK(assign(&dattr, vec.at(i))) - << "Incompatible attr in node " << attrs.name << " at " << i << "-th " + << "Incompatible attr in node " << node_name << " at " << i << "-th " << name << ": " << "expected " << attr_string(dattr) << ", got " << attr_string(vec.at(i)); } @@ -145,7 +145,7 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs, auto write = [&](std::vector *vec, size_t size, const char *name) { for (size_t i = 0; i < size; ++i) { CHECK(assign(&(vec->at(i)), dattr)) - << "Incompatible attr in node " << attrs.name << " at " << i << "-th " + << "Incompatible attr in node " << node_name << " at " << i << "-th " << name << ": " << "expected " << attr_string(dattr) << ", got " << attr_string(vec->at(i)); } @@ -158,6 +158,21 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs, return true; } + +template +inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs, + const AttrType& none) { + 
diff --git a/src/operator/slice_channel-inl.h b/src/operator/slice_channel-inl.h
index e37ffdcf1b91..6f2aa2f4d17c 100644
--- a/src/operator/slice_channel-inl.h
+++ b/src/operator/slice_channel-inl.h
@@ -36,6 +36,7 @@
 #include <vector>
 #include "./operator_common.h"
 #include "./channel_op_common.h"
+#include "./elemwise_op_common.h"
 
 namespace mxnet {
 namespace op {
@@ -176,16 +177,10 @@ class SliceChannelProp : public OperatorProperty {
   bool InferType(std::vector<int> *in_type,
                  std::vector<int> *out_type,
                  std::vector<int> *aux_type) const override {
-    CHECK_EQ(in_type->size(), 1U);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    out_type->clear();
-    out_type->reserve(param_.num_outputs);
-    for (int i = 0; i < param_.num_outputs; ++i) {
-      out_type->push_back(dtype);
-    }
-    aux_type->clear();
-    return true;
+    std::string node_name = "slice_channel_node";
+    return ElemwiseAttrHelper<int, type_is_none,
+                              type_assign, true,
+                              type_string>(node_name, in_type, out_type, -1);
   }
 
   bool InferShape(mxnet::ShapeVector *in_shape,
diff --git a/tests/python/gpu/test_contrib_amp.py b/tests/python/gpu/test_contrib_amp.py
index c49fa9b49865..74fb29c3f6f6 100644
--- a/tests/python/gpu/test_contrib_amp.py
+++ b/tests/python/gpu/test_contrib_amp.py
@@ -475,6 +475,15 @@ def test_fp16_casting():
     exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
     assert exe.arg_arrays[0].dtype == np.float16
 
+    # Check for symbol which has slice channel
+    data = mx.sym.var("data")
+    data2 = mx.sym.var("data2")
+    data._set_attr(__dtype__="-1")
+    data2._set_attr(__dtype__="-1")
+    concat_res = mx.sym.concat(data, data2)
+    out = mx.sym.split(concat_res, axis=1, num_outputs=2)
+    final_res = amp.convert_symbol(out)
+
 
 if __name__ == '__main__':
     import nose
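
The new test case only checks that amp.convert_symbol succeeds on a graph containing split when the input dtypes are unset; it does not bind the result or assert anything yet. A possible follow-up, mirroring the earlier cases in test_fp16_casting, is sketched below; it is hypothetical and not part of the patch, assumes a GPU build, and only inspects the dtypes instead of asserting a particular outcome:

    import mxnet as mx
    from mxnet.contrib import amp

    data = mx.sym.var("data")
    data2 = mx.sym.var("data2")
    data._set_attr(__dtype__="-1")
    data2._set_attr(__dtype__="-1")
    concat_res = mx.sym.concat(data, data2)
    out = mx.sym.split(concat_res, axis=1, num_outputs=2)
    final_res = amp.convert_symbol(out)

    # Bind the converted symbol the way the earlier cases do and look at the
    # dtypes AMP settled on for the inputs and the two split outputs.
    exe = final_res.simple_bind(ctx=mx.gpu(), data=(1, 2), data2=(1, 2))
    print([arr.dtype for arr in exe.arg_arrays])
    print([o.dtype for o in exe.outputs])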