[Enhance] support TensorRT engine for onnxruntime #1739
Conversation
@@ -36,7 +38,8 @@ class ORTWrapper(BaseWrapper):
     def __init__(self,
                  onnx_file: str,
                  device: str,
-                 output_names: Optional[Sequence[str]] = None):
+                 output_names: Optional[Sequence[str]] = None,
+                 enable_trt: bool = False):
Is it possible to check whether the TensorRT provider is available in the current environment?
If it is, we can check it inside __init__ instead of adding a flag.
>>> available_providers = onnxruntime.get_available_providers()
>>> available_providers
['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
onnxruntime has a get_available_providers function that tells you which providers are available in the current environment.
Changed to use the TensorRT executor if it is available. Thank you :)
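For reference, a minimal sketch of how that availability check could replace the flag; select_providers is a hypothetical helper for illustration, not code from this PR:

import onnxruntime as ort

def select_providers(device: str):
    """Build a providers list from what the current environment offers."""
    available = ort.get_available_providers()
    providers = []
    if device != 'cpu':
        # Prefer the TensorRT executor when the environment provides it,
        # then fall back to the plain CUDA executor.
        if 'TensorrtExecutionProvider' in available:
            providers.append('TensorrtExecutionProvider')
        if 'CUDAExecutionProvider' in available:
            providers.append('CUDAExecutionProvider')
    # CPU executor last, as the default fallback.
    providers.append('CPUExecutionProvider')
    return providers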
if device == 'cpu':
    providers.append('CPUExecutionProvider')
else:
    providers.append(('CUDAExecutionProvider', {
As far as I know, the CPU provider can co-exist with the CUDA provider; ONNX Runtime will fall back to the CPU if the CUDA implementation of an op is not provided.
The CPU provider can be placed in the providers list even when we use a GPU device. (I am not sure whether the order of the providers matters.)
According to the official documentation, the order of the providers expresses preference. Therefore, when the device is cuda, the TensorRT and CUDA executors are included, and the CPU executor is appended at the end as the fallback. (https://onnxruntime.ai/docs/execution-providers/)
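A short sketch of what that ordering looks like when passed to an InferenceSession; the model path and device_id value are placeholders, not taken from this PR:

import onnxruntime as ort

# Order expresses preference: TensorRT first, then CUDA, with CPU as the fallback.
providers = [
    ('TensorrtExecutionProvider', {'device_id': 0}),
    ('CUDAExecutionProvider', {'device_id': 0}),
    'CPUExecutionProvider',
]
sess = ort.InferenceSession('model.onnx', providers=providers)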
LGTM
Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and help it receive feedback more easily. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.
Motivation
With the TensorRT execution provider, ONNX Runtime delivers better inferencing performance on the same hardware compared to generic GPU acceleration. The TensorRT execution provider in ONNX Runtime makes use of NVIDIA's TensorRT Deep Learning inferencing engine to accelerate ONNX models on their family of GPUs.
https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html
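To illustrate the feature from the user's side, a hedged sketch of running an exported model with the TensorRT executor enabled; the model path, input shape, and the TensorRT options shown (trt_fp16_enable, trt_engine_cache_enable) are illustrative and not part of this PR:

import numpy as np
import onnxruntime as ort

providers = [
    ('TensorrtExecutionProvider', {
        'trt_fp16_enable': True,          # build FP16 engines where supported
        'trt_engine_cache_enable': True,  # cache built engines across runs
    }),
    'CUDAExecutionProvider',
    'CPUExecutionProvider',
]
sess = ort.InferenceSession('end2end.onnx', providers=providers)
input_name = sess.get_inputs()[0].name
dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)
outputs = sess.run(None, {input_name: dummy})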
Modification
Please briefly describe what modification is made in this PR.
BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
Checklist