This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Conversion from FP32 model to Mixed Precision model #15118
Merged
45 commits
36e5579 Initial AMP commit (anirudh2290)
70409d0 Fix (anirudh2290)
ae3734f Merge AMP Changes (anirudh2290)
9f041cc AMP Changes to support conditional op names switch (anirudh2290)
4dce69e Add example and fix issues with AMP conversion (anirudh2290)
8d63335 Remove amp convert symbol test (anirudh2290)
e526c16 Fix comment for inference use case (anirudh2290)
888daa7 Remove input_names for convert_hybrid_block (anirudh2290)
ea8b220 Check all conditions (anirudh2290)
eded365 Fix lint (anirudh2290)
be5d0dd Fix error_str for load_dict (anirudh2290)
3e8ca54 Fix lint, Add tests, fix bugs, add examples (anirudh2290)
7640f50 Fix warnings (anirudh2290)
42967e8 Add license for example script (anirudh2290)
f502d74 Remove gpu test and move tests to test_contrib_amp (anirudh2290)
f7d051d Clean up AMP tests (anirudh2290)
57060e7 Add additional comments, add tutorial (anirudh2290)
5a4b1f7 Move the test to gpu dir (anirudh2290)
7e1feae Make the code python3 compatible (anirudh2290)
ea7dd32 Upgrade archive utility, fixes: #15084 (anirudh2290)
94156b6 Allow AR path to be chosen by user (anirudh2290)
9ef7bd3 Use current_context in tutorial (anirudh2290)
b5173b9 Update __all__ (anirudh2290)
f9d09a4 Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (anirudh2290)
9297970 Merge with load params API changes (anirudh2290)
eb186d0 Revert "Allow AR path to be chosen by user" (anirudh2290)
80dd7bc Revert "Upgrade archive utility, fixes: #15084" (anirudh2290)
1ea508f Set numpy dtype to float32 (anirudh2290)
8e52789 Address review comments (anirudh2290)
61e942f Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (anirudh2290)
bba14e0 Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (anirudh2290)
9c72372 Add range based for (anirudh2290)
89ea0cc Change quantized to low precision (anirudh2290)
ed1b814 Fix lint (anirudh2290)
65ebc74 Merge branch 'master' of https://github.com/apache/incubator-mxnet in… (anirudh2290)
43bc0c9 Fix pylint (anirudh2290)
b70ab2f Forward args for Node::Create (anirudh2290)
ed82db1 Fixes (anirudh2290)
1200dde Add dtype casting wherever needed (anirudh2290)
85a50e2 Fix lint in source (anirudh2290)
22d3a76 Add cast_optional_params to example (anirudh2290)
383d664 Tweak example (anirudh2290)
2480273 Add README (anirudh2290)
8df637d Add README (anirudh2290)
9903222 Add cast_optional_params test for convert_model and convert_hybrid_bloc (anirudh2290)
@@ -0,0 +1,35 @@
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->

# Conversion of FP32 models to Mixed Precision Models


This folder contains examples for converting FP32 models to mixed precision models. The script converts FP32 symbolic models or Gluon models to mixed precision models.

## Basic Usage

1. AMP model conversion for a Gluon model, casting the params to FP16 wherever possible. The script below converts the `resnet101_v1` model to a mixed precision model, casts params to FP16 wherever possible, loads the converted model, and runs inference on it.

```bash
python amp_model_conversion.py --model resnet101_v1 --use-gluon-model --run-dummy-inference --cast-optional-params
```

2. AMP model conversion for a symbolic model, keeping the params in FP32 wherever possible (`--cast-optional-params` not used).

```bash
python amp_model_conversion.py --model imagenet1k-resnet-152 --run-dummy-inference
```
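The two invocations above differ only in their flags. As a rough sketch of how those flags are read, the snippet below rebuilds a reduced version of the argument parser from `amp_model_conversion.py` in this pull request and parses both command lines. The helper name `build_parser` is hypothetical, and the `choices` restriction and help strings of the real parser are left out for brevity.

```python
import argparse


def build_parser():
    """Hypothetical helper: a reduced copy of the example script's parser."""
    parser = argparse.ArgumentParser(
        description='Convert a provided FP32 model to a mixed precision model')
    parser.add_argument('--model', type=str)
    parser.add_argument('--run-dummy-inference', action='store_true', default=False)
    parser.add_argument('--use-gluon-model', action='store_true', default=False)
    parser.add_argument('--cast-optional-params', action='store_true', default=False)
    return parser


# Usage 1 from the README: Gluon model, params cast to FP16 wherever possible.
gluon_args = build_parser().parse_args(
    ['--model', 'resnet101_v1', '--use-gluon-model',
     '--run-dummy-inference', '--cast-optional-params'])

# Usage 2 from the README: symbolic model, optional params left in FP32.
symbolic_args = build_parser().parse_args(
    ['--model', 'imagenet1k-resnet-152', '--run-dummy-inference'])

print(vars(gluon_args))
print(vars(symbolic_args))
```

Because every flag uses `action='store_true'`, leaving `--cast-optional-params` off the command line is what keeps the optional params in FP32 in the second usage.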
119 changes: 119 additions & 0 deletions
example/automatic-mixed-precision/amp_model_conversion.py
@@ -0,0 +1,119 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import logging
import argparse
import mxnet as mx
from common import modelzoo
import gluoncv
from gluoncv.model_zoo import get_model
from mxnet.contrib.amp import amp
import numpy as np

def download_model(model_name, logger=None):
    dir_path = os.path.dirname(os.path.realpath(__file__))
    model_path = os.path.join(dir_path, 'model')
    if logger is not None:
        logger.info('Downloading model {}... into path {}'.format(model_name, model_path))
    return modelzoo.download_model(model_name, model_path)


def save_symbol(fname, sym, logger=None):
    if logger is not None:
        logger.info('Saving symbol into file at {}'.format(fname))
    sym.save(fname, remove_amp_cast=False)


def save_params(fname, arg_params, aux_params, logger=None):
    if logger is not None:
        logger.info('Saving params into file at {}'.format(fname))
    save_dict = {('arg:%s' % k): v.as_in_context(mx.cpu()) for k, v in arg_params.items()}
    save_dict.update({('aux:%s' % k): v.as_in_context(mx.cpu()) for k, v in aux_params.items()})
    mx.nd.save(fname, save_dict)


if __name__ == '__main__':
    symbolic_models = ['imagenet1k-resnet-152',
                       'imagenet1k-resnet-18',
                       'imagenet1k-resnet-34',
                       'imagenet1k-resnet-50',
                       'imagenet1k-resnet-101',
                       'imagenet1k-resnext-50',
                       'imagenet1k-resnext-101',
                       'imagenet1k-resnext-101-64x4d',
                       'imagenet11k-place365ch-resnet-152',
                       'imagenet11k-place365ch-resnet-50']
    gluon_models = ['resnet18_v1',
                    'resnet50_v1',
                    'resnet101_v1',
                    'squeezenet1.0',
                    'mobilenet1.0',
                    'mobilenetv2_1.0',
                    'inceptionv3']
    models = symbolic_models + gluon_models

    parser = argparse.ArgumentParser(description='Convert a provided FP32 model to a mixed precision model')
    parser.add_argument('--model', type=str, choices=models)
    parser.add_argument('--run-dummy-inference', action='store_true', default=False,
                        help='Will generate random input of shape (1, 3, 224, 224) '
                             'and run a dummy inference forward pass')
    parser.add_argument('--use-gluon-model', action='store_true', default=False,
                        help='If enabled, will download pretrained model from Gluon-CV '
                             'and convert to mixed precision model ')
    parser.add_argument('--cast-optional-params', action='store_true', default=False,
                        help='If enabled, will try to cast params to target dtype wherever possible')
    args = parser.parse_args()
    logging.basicConfig()
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)

    if not args.use_gluon_model:
        assert args.model in symbolic_models, "Please choose one of the available symbolic models: {} \
                                               If you want to use gluon use the script with --use-gluon-model".format(symbolic_models)

        prefix, epoch = download_model(model_name=args.model, logger=logger)
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        result_sym, result_arg_params, result_aux_params = amp.convert_model(sym, arg_params, aux_params,
                                                                             cast_optional_params=args.cast_optional_params)
        sym_name = "%s-amp-symbol.json" % (prefix)
        save_symbol(sym_name, result_sym, logger)
        param_name = '%s-%04d.params' % (prefix + '-amp', epoch)
        save_params(param_name, result_arg_params, result_aux_params, logger)
        if args.run_dummy_inference:
            logger.info("Running inference on the mixed precision model with dummy input, batch size: 1")
            mod = mx.mod.Module(result_sym, data_names=['data'], label_names=['softmax_label'], context=mx.gpu(0))
            mod.bind(data_shapes=[['data', (1, 3, 224, 224)]], label_shapes=[['softmax_label', (1,)]])
            mod.set_params(arg_params, aux_params)
            mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))],
                                        label=[mx.nd.ones((1,))]))
            result = mod.get_outputs()[0].asnumpy()
            logger.info("Inference run successfully")
    else:
        assert args.model in gluon_models, "Please choose one of the available gluon models: {} \
                                            If you want to use symbolic model instead, remove --use-gluon-model when running the script".format(gluon_models)
        net = gluoncv.model_zoo.get_model(args.model, pretrained=True)
        net.hybridize()
        result_before1 = net.forward(mx.nd.zeros((1, 3, 224, 224)))
        net.export("{}".format(args.model))
        net = amp.convert_hybrid_block(net, cast_optional_params=args.cast_optional_params)
        net.export("{}-amp".format(args.model), remove_amp_cast=False)
        if args.run_dummy_inference:
            logger.info("Running inference on the mixed precision model with dummy inputs, batch size: 1")
            result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
            result_after = net.forward(mx.nd.zeros((1, 3, 224, 224), dtype=np.float32, ctx=mx.gpu(0)))
            logger.info("Inference run successfully")
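The symbolic path of the script writes two files and stores parameters under `arg:` and `aux:` key prefixes. The sketch below reproduces just that naming logic with plain Python values so it can be run without MXNet. The function names `amp_output_names`, `pack_params` and `unpack_params` are hypothetical, the checkpoint prefix is an assumed example value, and the lists stand in for NDArrays; `unpack_params` is only an illustration of how such a dictionary could be split again, not the loader MXNet uses.

```python
def amp_output_names(prefix, epoch):
    """Return the symbol and params file names used by the symbolic path."""
    sym_name = "%s-amp-symbol.json" % prefix
    param_name = "%s-%04d.params" % (prefix + "-amp", epoch)
    return sym_name, param_name


def pack_params(arg_params, aux_params):
    """Merge both dicts into one, tagging each key as in save_params."""
    save_dict = {("arg:%s" % k): v for k, v in arg_params.items()}
    save_dict.update({("aux:%s" % k): v for k, v in aux_params.items()})
    return save_dict


def unpack_params(save_dict):
    """Split a tagged dict back into (arg_params, aux_params)."""
    arg_params, aux_params = {}, {}
    for key, value in save_dict.items():
        kind, name = key.split(":", 1)
        if kind == "arg":
            arg_params[name] = value
        elif kind == "aux":
            aux_params[name] = value
    return arg_params, aux_params


# Assumed prefix and epoch; the real values come from modelzoo.download_model.
print(amp_output_names("model/imagenet1k-resnet-152", 0))

# Plain lists stand in for the NDArray values held by the real script.
packed = pack_params({"fc_weight": [1.0]}, {"bn_moving_mean": [0.0]})
print(sorted(packed))
print(unpack_params(packed))
```

Keeping both parameter groups in a single tagged dictionary is what lets the script save them with one `mx.nd.save` call.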
@@ -0,0 +1 @@
../image-classification/common
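The one-line file above reads like the target of a symbolic link. Its file name is not shown here, so the following is only a sketch under the assumption that the link is named `common` and sits next to the example script, which would let `from common import modelzoo` resolve to the shared image-classification helpers. The directory layout is rebuilt in a temporary folder, and the `modelzoo.py` written there is a stand-in, not the real module.

```python
import importlib
import os
import sys
import tempfile

# Rebuild an assumed layout: two sibling example directories.
root = tempfile.mkdtemp()
shared = os.path.join(root, "image-classification", "common")
example = os.path.join(root, "automatic-mixed-precision")
os.makedirs(shared)
os.makedirs(example)

# Stand-in package contents; the real modelzoo module is not reproduced here.
with open(os.path.join(shared, "__init__.py"), "w") as f:
    f.write("")
with open(os.path.join(shared, "modelzoo.py"), "w") as f:
    f.write("def download_model(name, dst):\n"
            "    return (dst + '/' + name, 0)\n")

# A relative link target is resolved against the directory holding the link.
link = os.path.join(example, "common")
os.symlink("../image-classification/common", link)

# With the example directory on sys.path, the import follows the link.
sys.path.insert(0, example)
importlib.invalidate_caches()
modelzoo = importlib.import_module("common.modelzoo")

print(os.readlink(link))
print(modelzoo.download_model("demo", "model"))
```

A link like this avoids copying the helper package into each example directory, at the cost of requiring a filesystem and checkout that support symbolic links.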