diff --git a/docs/source/onnxruntime/modeling_ort.mdx b/docs/source/onnxruntime/modeling_ort.mdx index 5bbafe4388..09f9a2a8fa 100644 --- a/docs/source/onnxruntime/modeling_ort.mdx +++ b/docs/source/onnxruntime/modeling_ort.mdx @@ -12,13 +12,13 @@ specific language governing permissions and limitations under the License. # Optimum Inference with ONNX Runtime -Optimum is a utility package for building and running inference with accelerated runtime like ONNX Runtime. -Optimum can be used to load optimized models from the [Hugging Face Hub](hf.co/models) and create pipelines +Optimum is a utility package for building and running inference with accelerated runtimes like ONNX Runtime. +Optimum can be used to load optimized models from the [Hugging Face Hub](https://hf.co/models) and create pipelines to run accelerated inference without rewriting your APIs. ## Switching from Transformers to Optimum Inference -The Optimum Inference models are API compatible with Hugging Face Transformers models. This means you can just replace your `AutoModelForXxx` class with the corresponding `ORTModelForXxx` class in `optimum`. For example, this is how you can use a question answering model in `optimum`: +The Optimum Inference models are API compatible with Hugging Face Transformers models. This means you can just replace your `AutoModelForXxx` class with the corresponding `ORTModelForXxx` class in `optimum`. For example, this is how you can use a question answering model in `optimum`: ```diff from transformers import AutoTokenizer, pipeline @@ -57,8 +57,8 @@ You can find a complete walkhrough Optimum Inference for ONNX Runtime in this [n ### Working with the Hugging Face Model Hub -The Optimum model classes like [`~onnxruntime.ORTModelForSequenceClassification`] are integrated with the [Hugging Face Model Hub](https://hf.co/models), which means you can not only -load model from the Hub, but also push your models to the Hub with `push_to_hub()` method. Below is an example which downloads a vanilla Transformers model +The Optimum model classes like [`~onnxruntime.ORTModelForSequenceClassification`] are integrated with the [Hugging Face Model Hub](https://hf.co/models), which means you can not only +load models from the Hub, but also push your models to the Hub with the `push_to_hub()` method. Below is an example which downloads a vanilla Transformers model from the Hub and converts it to an optimum onnxruntime model and pushes it back into a new repository. @@ -105,3 +105,7 @@ from the Hub and converts it to an optimum onnxruntime model and pushes it back [[autodoc]] onnxruntime.modeling_ort.ORTModelForCausalLM +## ORTModelForImageClassification + +[[autodoc]] onnxruntime.modeling_ort.ORTModelForImageClassification + diff --git a/docs/source/pipelines.mdx b/docs/source/pipelines.mdx index 14ea3b2fb6..f4b135cb4a 100644 --- a/docs/source/pipelines.mdx +++ b/docs/source/pipelines.mdx @@ -12,8 +12,7 @@ specific language governing permissions and limitations under the License. # Optimum pipelines for inference -The [`~pipelines.pipeline`] function makes it simple to use models from the [Model Hub](https://huggingface.co/models) for accelerated inference on a variety of tasks such as text classification. -Even if you don't have experience with a specific modality or understand the code powering the models, you can still use them with the [`~pipelines.pipeline`] function!
+The [`~pipelines.pipeline`] function makes it simple to use models from the [Model Hub](https://huggingface.co/models) for accelerated inference on a variety of tasks such as text classification, question answering and image classification. @@ -31,11 +30,12 @@ Currenlty supported tasks are: * `question-answering` * `zero-shot-classification` * `text-generation` +* `image-classification` ## Optimum pipeline usage -While each task has an associated pipeline class, it is simpler to use the general [`~pipelines.pipeline`] function which wraps all the task-specific pipelines in one object. -The [`~pipelines.pipeline`] function automatically loads a default model and tokenizer capable of inference for your task. +While each task has an associated pipeline class, it is simpler to use the general [`~pipelines.pipeline`] function which wraps all the task-specific pipelines in one object. +The [`~pipelines.pipeline`] function automatically loads a default model and tokenizer/feature-extractor capable of inference for your task. 1. Start by creating a pipeline by specifying an inference task: @@ -46,7 +46,7 @@ The [`~pipelines.pipeline`] function automatically loads a default model and tok ``` -2. Pass your input text to the [`~pipelines.pipeline`] function: +2. Pass your input text/image to the [`~pipelines.pipeline`] function: ```python >>> classifier("I like you. I love you.") @@ -57,9 +57,9 @@ _Note: The default models used in the [`~pipelines.pipeline`] function are not o ### Using vanilla Transformers model and converting to ONNX -The [`~pipelines.pipeline`] function accepts any supported model from the [Model Hub](https://huggingface.co/models). -There are tags on the Model Hub that allow you to filter for a model you'd like to use for your task. -Once you've picked an appropriate model, load it with the `from_pretrained("{model_id}",from_transformers=True)` method associated with the `ORTModelFor*` +The [`~pipelines.pipeline`] function accepts any supported model from the [Model Hub](https://huggingface.co/models). +There are tags on the Model Hub that allow you to filter for a model you'd like to use for your task. +Once you've picked an appropriate model, load it with the `from_pretrained("{model_id}",from_transformers=True)` method associated with the `ORTModelFor*` and `AutoTokenizer' class. For example, here's how you can load the [`~onnxruntime.ORTModelForQuestionAnswering`] class for question answering: ```python >>> from transformers import AutoTokenizer @@ -80,10 +80,10 @@ Once you've picked an appropriate model, load it with the `from_pretrained("{mod ### Using Optimum models -The [`~pipelines.pipeline`] function is tightly integrated with [Model Hub](https://huggingface.co/models) and can load optimized models directly, e.g. those created with ONNX Runtime. -There are tags on the Model Hub that allow you to filter for a model you'd like to use for your task. +The [`~pipelines.pipeline`] function is tightly integrated with the [Model Hub](https://huggingface.co/models) and can load optimized models directly, e.g. those created with ONNX Runtime. +There are tags on the Model Hub that allow you to filter for a model you'd like to use for your task. Once you've picked an appropriate model, load it with the `from_pretrained()` method associated with the corresponding `ORTModelFor*` -and `AutoTokenizer' class. For example, here's how you can load an optimized model for question answering: +and `AutoTokenizer`/`AutoFeatureExtractor` class.
For example, here's how you can load an optimized model for question answering: ```python >>> from transformers import AutoTokenizer @@ -132,7 +132,7 @@ Below you can find two examples on how you could [`~onnxruntime.ORTOptimizer`] a onnx_quantized_model_output_path=save_path / "model-quantized.onnx", quantization_config=qconfig, ) ->>> quantizer.model.config.save_pretrained(save_path) # saves config.json +>>> quantizer.model.config.save_pretrained(save_path) # saves config.json # load optimized model from local path or repository >>> model = ORTModelForSequenceClassification.from_pretrained(save_path,file_name="model-quantized.onnx") @@ -176,7 +176,7 @@ Below you can find two examples on how you could [`~onnxruntime.ORTOptimizer`] a onnx_optimized_model_output_path=save_path / "model-optimized.onnx", optimization_config=optimization_config, ) ->>> optimizer.model.config.save_pretrained(save_path) # saves config.json +>>> optimizer.model.config.save_pretrained(save_path) # saves config.json # load optimized model from local path or repository >>> model = ORTModelForSequenceClassification.from_pretrained(save_path,file_name="model-optimized.onnx") @@ -198,8 +198,8 @@ Below you can find two examples on how you could [`~onnxruntime.ORTOptimizer`] a ## Transformers pipeline usage The [`~pipelines.pipeline`] function is just a light wrapper around the `transformers.pipeline` function to enable checks for supported tasks and additional features -, like quantization and optimization. This being said you can use the `transformers.pipeline` and just replace your `AutoFor*` with the optimum - `ORTModelFor*` class. +, like quantization and optimization. This being said, you can use `transformers.pipeline` and just replace your `AutoModelFor*` with the optimum + `ORTModelFor*` class.
```diff from transformers import AutoTokenizer, pipeline @@ -207,7 +207,7 @@ from transformers import AutoTokenizer, pipeline +from optimum.onnxruntime import ORTModelForQuestionAnswering -model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2") -+model = ORTModelForQuestionAnswering.from_transformers("optimum/roberta-base-squad2") ++model = ORTModelForQuestionAnswering.from_pretrained("optimum/roberta-base-squad2") tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2") onnx_qa = pipeline("question-answering",model=model,tokenizer=tokenizer) diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py index 20ae1283e5..ecc71e9fd5 100644 --- a/optimum/onnxruntime/__init__.py +++ b/optimum/onnxruntime/__init__.py @@ -53,6 +53,7 @@ class ORTQuantizableOperator(Enum): from .modeling_ort import ( ORTModelForCausalLM, ORTModelForFeatureExtraction, + ORTModelForImageClassification, ORTModelForQuestionAnswering, ORTModelForSequenceClassification, ORTModelForTokenClassification, diff --git a/optimum/onnxruntime/configuration.py b/optimum/onnxruntime/configuration.py index a33135aa15..f56fb7726f 100644 --- a/optimum/onnxruntime/configuration.py +++ b/optimum/onnxruntime/configuration.py @@ -20,7 +20,6 @@ from datasets import Dataset from packaging.version import Version, parse -from onnxruntime import GraphOptimizationLevel from onnxruntime import __version__ as ort_version from onnxruntime.quantization import CalibraterBase, CalibrationMethod, QuantFormat, QuantizationMode, QuantType from onnxruntime.quantization.calibrate import create_calibrator diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 1b631a0e6f..70d03ca60b 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -9,10 +9,10 @@ AutoConfig, AutoModel, AutoModelForCausalLM, + AutoModelForImageClassification, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoModelForTokenClassification, - AutoTokenizer, PretrainedConfig, ) from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, default_cache_path @@ -20,11 +20,13 @@ from transformers.modeling_outputs import ( BaseModelOutput, CausalLMOutputWithCrossAttentions, + ImageClassifierOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput, ) from transformers.onnx import FeaturesManager, export +from transformers.onnx.utils import get_preprocessor import onnxruntime as ort from huggingface_hub import HfApi, hf_hub_download @@ -37,6 +39,7 @@ _TOKENIZER_FOR_DOC = "AutoTokenizer" +_FEATURE_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor" ONNX_MODEL_START_DOCSTRING = r""" This model inherits from [~`onnxruntime.modeling_ort.ORTModel`]. Check the superclass documentation for the generic methods the @@ -52,7 +55,7 @@ Args: input_ids (`torch.Tensor` of shape `({0})`): Indices of input sequence tokens in the vocabulary. - Indices can be obtained using [`AutoTokenizer`](https://huggingface.co/docs/transformers/autoclass_tutorial#autotokenizer). + Indices can be obtained using [`AutoTokenizer`](https://huggingface.co/docs/transformers/autoclass_tutorial#autotokenizer). See [`PreTrainedTokenizer.encode`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizerBase.encode) and [`PreTrainedTokenizer.__call__`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizerBase.__call__) for details. 
[What are input IDs?](https://huggingface.co/docs/transformers/glossary#input-ids) @@ -256,14 +259,14 @@ def _from_transformers( task = "default" # 2. convert to temp dir # FIXME: transformers.onnx conversion doesn't support private models - tokenizer = AutoTokenizer.from_pretrained(model_id) + preprocessor = get_preprocessor(model_id) model = FeaturesManager.get_model_from_feature(task, model_id) _, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=task) onnx_config = model_onnx_config(model.config) # export model export( - preprocessor=tokenizer, + preprocessor=preprocessor, model=model, config=onnx_config, opset=onnx_config.default_onnx_opset, @@ -274,7 +277,7 @@ def _from_transformers( return cls._from_pretrained(save_dir.as_posix(), **kwargs) -FEAUTRE_EXTRACTION_SAMPLE = r""" +FEAUTRE_EXTRACTION_EXAMPLE = r""" Example of feature extraction: ```python @@ -323,14 +326,14 @@ class ORTModelForFeatureExtraction(ORTModel): pipeline_task = "default" auto_model_class = AutoModel - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, model=None, config=None, **kwargs): + super().__init__(model, config, **kwargs) # create {name:idx} dict for model outputs self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())} @add_start_docstrings_to_model_forward( ONNX_INPUTS_DOCSTRING.format("batch_size, sequence_length") - + FEAUTRE_EXTRACTION_SAMPLE.format( + + FEAUTRE_EXTRACTION_EXAMPLE.format( processor_class=_TOKENIZER_FOR_DOC, model_class="ORTModelForFeatureExtraction", checkpoint="optimum/all-MiniLM-L6-v2", @@ -357,7 +360,7 @@ def forward( return BaseModelOutput(last_hidden_state=last_hidden_state) -QUESTION_ANSWERING_SAMPLE = r""" +QUESTION_ANSWERING_EXAMPLE = r""" Example of question answering: ```python @@ -408,14 +411,14 @@ class ORTModelForQuestionAnswering(ORTModel): pipeline_task = "question-answering" auto_model_class = AutoModelForQuestionAnswering - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, model=None, config=None, **kwargs): + super().__init__(model, config, **kwargs) # create {name:idx} dict for model outputs self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())} @add_start_docstrings_to_model_forward( ONNX_INPUTS_DOCSTRING.format("batch_size, sequence_length") - + QUESTION_ANSWERING_SAMPLE.format( + + QUESTION_ANSWERING_EXAMPLE.format( processor_class=_TOKENIZER_FOR_DOC, model_class="ORTModelForQuestionAnswering", checkpoint="optimum/roberta-base-squad2", @@ -446,7 +449,7 @@ def forward( ) -SEQUENCE_CLASSIFICATION_SAMPLE = r""" +SEQUENCE_CLASSIFICATION_EXAMPLE = r""" Example of single-label classification: ```python @@ -511,15 +514,15 @@ class ORTModelForSequenceClassification(ORTModel): pipeline_task = "sequence-classification" auto_model_class = AutoModelForSequenceClassification - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, model=None, config=None, **kwargs): + super().__init__(model, config, **kwargs) # create {name:idx} dict for model outputs self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())} self.model_inputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_inputs())} @add_start_docstrings_to_model_forward( ONNX_INPUTS_DOCSTRING.format("batch_size, sequence_length") - + SEQUENCE_CLASSIFICATION_SAMPLE.format( + + 
SEQUENCE_CLASSIFICATION_EXAMPLE.format( processor_class=_TOKENIZER_FOR_DOC, model_class="ORTModelForSequenceClassification", checkpoint="optimum/distilbert-base-uncased-finetuned-sst-2-english", @@ -549,7 +552,7 @@ def forward( ) -TOKEN_CLASSIFICATION_SAMPLE = r""" +TOKEN_CLASSIFICATION_EXAMPLE = r""" Example of token classification: ```python @@ -599,14 +602,14 @@ class ORTModelForTokenClassification(ORTModel): pipeline_task = "token-classification" auto_model_class = AutoModelForTokenClassification - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, model=None, config=None, **kwargs): + super().__init__(model, config, **kwargs) # create {name:idx} dict for model outputs self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())} @add_start_docstrings_to_model_forward( ONNX_INPUTS_DOCSTRING.format("batch_size, sequence_length") - + TOKEN_CLASSIFICATION_SAMPLE.format( + + TOKEN_CLASSIFICATION_EXAMPLE.format( processor_class=_TOKENIZER_FOR_DOC, model_class="ORTModelForTokenClassification", checkpoint="optimum/bert-base-NER", @@ -635,7 +638,7 @@ def forward( ) -TEXT_GENERATION_SAMPLE = r""" +TEXT_GENERATION_EXAMPLE = r""" Example of text generation: ```python @@ -684,8 +687,8 @@ class ORTModelForCausalLM(ORTModel, GenerationMixin): pipeline_task = "causal-lm" auto_model_class = AutoModelForCausalLM - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, model=None, config=None, **kwargs): + super().__init__(model, config, **kwargs) # create {name:idx} dict for model outputs self.main_input_name = "input_ids" self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())} @@ -701,7 +704,7 @@ def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) - @add_start_docstrings_to_model_forward( ONNX_INPUTS_DOCSTRING.format("batch_size, sequence_length") - + TEXT_GENERATION_SAMPLE.format( + + TEXT_GENERATION_EXAMPLE.format( processor_class=_TOKENIZER_FOR_DOC, model_class="ORTModelForCausalLM", checkpoint="optimum/gpt2", @@ -748,3 +751,87 @@ def _prepare_attention_mask_for_generation( else: # Ensure attention mask is on the same device as the input IDs return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) + + +IMAGE_CLASSIFICATION_EXAMPLE = r""" + Example of image classification: + + ```python + >>> import requests + >>> from PIL import Image + >>> from optimum.onnxruntime import {model_class} + >>> from transformers import {processor_class} + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> preprocessor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = preprocessor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> logits = outputs.logits + ``` + + Example using `transformers.pipeline`: + + ```python + >>> import requests + >>> from PIL import Image + >>> from transformers import {processor_class}, pipeline + >>> from optimum.onnxruntime import {model_class} + + >>> preprocessor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + >>> onnx_image_classifier = pipeline("image-classification", model=model, feature_extractor=preprocessor) + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> pred = onnx_image_classifier(url) + 
``` +""" + + +@add_start_docstrings( + """ + Onnx Model for image-classification tasks. + """, + ONNX_MODEL_START_DOCSTRING, +) +class ORTModelForImageClassification(ORTModel): + """ + Image Classification model for ONNX. + """ + + # used in from_transformers to export model to onnx + pipeline_task = "image-classification" + auto_model_class = AutoModelForImageClassification + + def __init__(self, model=None, config=None, **kwargs): + super().__init__(model, config, **kwargs) + # create {name:idx} dict for model outputs + self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())} + + @add_start_docstrings_to_model_forward( + ONNX_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + IMAGE_CLASSIFICATION_EXAMPLE.format( + processor_class=_FEATURE_EXTRACTOR_FOR_DOC, + model_class="ORTModelForImageClassification", + checkpoint="optimum/vit-base-patch16-224", + ) + ) + def forward( + self, + pixel_values: torch.Tensor, + **kwargs, + ): + # converts pytorch inputs into numpy inputs for onnx + onnx_inputs = { + "pixel_values": pixel_values.cpu().detach().numpy(), + } + # run inference + outputs = self.model.run(None, onnx_inputs) + # converts output to namedtuple for pipelines post-processing + return ImageClassifierOutput( + logits=torch.from_numpy(outputs[self.model_outputs["logits"]]), + ) diff --git a/optimum/onnxruntime/optimization.py b/optimum/onnxruntime/optimization.py index d1e4bacee8..83e7cbd982 100644 --- a/optimum/onnxruntime/optimization.py +++ b/optimum/onnxruntime/optimization.py @@ -16,9 +16,10 @@ from pathlib import Path from typing import Callable, Dict, List, Optional, Tuple, Union -from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizer +from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel from transformers.onnx import export from transformers.onnx.features import FeaturesManager +from transformers.onnx.utils import get_preprocessor from onnx import load_model from onnxruntime.transformers.fusion_options import FusionOptions @@ -41,7 +42,7 @@ def from_pretrained( model_name_or_path: Union[str, os.PathLike], feature: str, opset: Optional[int] = None ) -> "ORTOptimizer": """ - Instantiate a `ORTOptimizer` from a pretrained pytorch model and tokenizer. + Instantiate an `ORTOptimizer` from a pretrained pytorch model and preprocessor. Args: model_name_or_path (`Union[str, os.PathLike]`): @@ -54,22 +55,23 @@ def from_pretrained( Returns: An instance of `ORTOptimizer`. """ - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + preprocessor = get_preprocessor(model_name_or_path) model_class = FeaturesManager.get_model_class_for_feature(feature) model = model_class.from_pretrained(model_name_or_path) - return ORTOptimizer(tokenizer, model, feature, opset) + + return ORTOptimizer(preprocessor, model, feature, opset) def __init__( self, - tokenizer: PreTrainedTokenizer, + preprocessor: Union[AutoFeatureExtractor, AutoProcessor, AutoTokenizer], model: PreTrainedModel, feature: str = "default", opset: Optional[int] = None, ): """ Args: - tokenizer (`PreTrainedTokenizer`): - The tokenizer used to preprocess the data. + preprocessor (`Union[AutoFeatureExtractor, AutoProcessor, AutoTokenizer]`): + The preprocessor used to preprocess the data. model (`PreTrainedModel`): The model to optimize.
feature (`str`, defaults to `"default"`): @@ -79,7 +81,7 @@ def __init__( """ super().__init__() - self.tokenizer = tokenizer + self.preprocessor = preprocessor self.model = model self.feature = feature self._model_type, onnx_config_factory = FeaturesManager.check_supported_model_or_raise(model, feature=feature) @@ -117,7 +119,7 @@ def export( # Export the model if it has not already been exported to ONNX IR if not onnx_model_path.exists(): - export(self.tokenizer, self.model, self._onnx_config, self.opset, onnx_model_path) + export(self.preprocessor, self.model, self._onnx_config, self.opset, onnx_model_path) ORTConfigManager.check_supported_model_or_raise(self._model_type) num_heads = getattr(self.model.config, ORTConfigManager.get_num_heads_name(self._model_type)) diff --git a/optimum/onnxruntime/quantization.py b/optimum/onnxruntime/quantization.py index 346cab5d68..31b325a42b 100644 --- a/optimum/onnxruntime/quantization.py +++ b/optimum/onnxruntime/quantization.py @@ -20,9 +20,10 @@ from typing import Callable, Dict, List, Optional, Tuple, Union from datasets import Dataset, load_dataset -from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizer +from transformers import AutoFeatureExtractor, AutoTokenizer, PreTrainedModel from transformers.onnx import export from transformers.onnx.features import FeaturesManager +from transformers.onnx.utils import get_preprocessor import onnx from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantizationMode, QuantType @@ -85,7 +86,7 @@ def from_pretrained( model_name_or_path: Union[str, os.PathLike], feature: str, opset: Optional[int] = None ) -> "ORTQuantizer": """ - Instantiate a `ORTQuantizer` from a pretrained pytorch model and tokenizer. + Instantiate a `ORTQuantizer` from a pretrained pytorch model and preprocessor. Args: model_name_or_path (`Union[str, os.PathLike]`): @@ -98,23 +99,23 @@ def from_pretrained( Returns: An instance of `ORTQuantizer`. """ - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + preprocessor = get_preprocessor(model_name_or_path) model_class = FeaturesManager.get_model_class_for_feature(feature) model = model_class.from_pretrained(model_name_or_path) - return ORTQuantizer(tokenizer, model, feature, opset) + return ORTQuantizer(preprocessor, model, feature, opset) def __init__( self, - tokenizer: PreTrainedTokenizer, + preprocessor: Union[AutoTokenizer, AutoFeatureExtractor], model: PreTrainedModel, feature: str = "default", opset: Optional[int] = None, ): """ Args: - tokenizer (`PreTrainedTokenizer`): - The tokenizer used to preprocess the data. + preprocessor (`Union[AutoTokenizer, AutoFeatureExtractor]`): + The preprocessor used to preprocess the data. model (`PreTrainedModel`): The model to optimize. 
feature (`str`, defaults to `"default"`): @@ -124,7 +125,7 @@ def __init__( """ super().__init__() - self.tokenizer = tokenizer + self.preprocessor = preprocessor self.model = model self.feature = feature @@ -236,7 +237,7 @@ def partial_fit( # Export the model to ONNX IR if not onnx_model_path.exists(): - export(self.tokenizer, self.model, self._onnx_config, self.opset, onnx_model_path) + export(self.preprocessor, self.model, self._onnx_config, self.opset, onnx_model_path) LOGGER.info(f"Exported model to ONNX at: {onnx_model_path.as_posix()}") @@ -306,7 +307,7 @@ def export( # Export the model if it has not already been exported to ONNX IR (useful for dynamic quantization) if not onnx_model_path.exists(): - export(self.tokenizer, self.model, self._onnx_config, self.opset, onnx_model_path) + export(self.preprocessor, self.model, self._onnx_config, self.opset, onnx_model_path) use_qdq = quantization_config.is_static and quantization_config.format == QuantFormat.QDQ diff --git a/optimum/onnxruntime/runs/utils.py b/optimum/onnxruntime/runs/utils.py index 2d9d20296a..fb5fb36988 100644 --- a/optimum/onnxruntime/runs/utils.py +++ b/optimum/onnxruntime/runs/utils.py @@ -1,6 +1,7 @@ from optimum.onnxruntime.modeling_ort import ( ORTModelForCausalLM, ORTModelForFeatureExtraction, + ORTModelForImageClassification, ORTModelForQuestionAnswering, ORTModelForSequenceClassification, ORTModelForTokenClassification, @@ -8,9 +9,10 @@ task_ortmodel_map = { + "causal-lm": ORTModelForCausalLM, "feature-extraction": ORTModelForFeatureExtraction, + "image-classification": ORTModelForImageClassification, "question-answering": ORTModelForQuestionAnswering, "text-classification": ORTModelForSequenceClassification, "token-classification": ORTModelForTokenClassification, - "causal-lm": ORTModelForCausalLM, } diff --git a/optimum/pipelines.py b/optimum/pipelines.py index d68cabacff..837fa5c560 100644 --- a/optimum/pipelines.py +++ b/optimum/pipelines.py @@ -1,8 +1,8 @@ from typing import Any, Optional, Union from transformers import ( - AutoTokenizer, FeatureExtractionPipeline, + ImageClassificationPipeline, Pipeline, PreTrainedTokenizer, QuestionAnsweringPipeline, @@ -13,6 +13,9 @@ ) from transformers import pipeline as transformers_pipeline from transformers.feature_extraction_utils import PreTrainedFeatureExtractor +from transformers.models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING +from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING +from transformers.onnx.utils import get_preprocessor from optimum.utils import is_onnxruntime_available @@ -23,6 +26,7 @@ from optimum.onnxruntime import ( ORTModelForCausalLM, ORTModelForFeatureExtraction, + ORTModelForImageClassification, ORTModelForQuestionAnswering, ORTModelForSequenceClassification, ORTModelForTokenClassification, @@ -35,31 +39,36 @@ "class": (ORTModelForFeatureExtraction,) if is_onnxruntime_available() else (), "default": "distilbert-base-cased", }, + "image-classification": { + "impl": ImageClassificationPipeline, + "class": (ORTModelForImageClassification,) if is_onnxruntime_available() else (), + "default": "google/vit-base-patch16-224", + }, + "question-answering": { + "impl": QuestionAnsweringPipeline, + "class": (ORTModelForQuestionAnswering,) if is_onnxruntime_available() else (), + "default": "distilbert-base-cased-distilled-squad", + }, "text-classification": { "impl": TextClassificationPipeline, "class": (ORTModelForSequenceClassification,) if is_onnxruntime_available() else (), "default": 
"distilbert-base-uncased-finetuned-sst-2-english", }, + "text-generation": { + "impl": TextGenerationPipeline, + "class": (ORTModelForCausalLM,) if is_onnxruntime_available() else (), + "default": "distilgpt2", + }, "token-classification": { "impl": TokenClassificationPipeline, "class": (ORTModelForTokenClassification,) if is_onnxruntime_available() else (), "default": "dbmdz/bert-large-cased-finetuned-conll03-english", }, - "question-answering": { - "impl": QuestionAnsweringPipeline, - "class": (ORTModelForQuestionAnswering,) if is_onnxruntime_available() else (), - "default": "distilbert-base-cased-distilled-squad", - }, "zero-shot-classification": { "impl": ZeroShotClassificationPipeline, "class": (ORTModelForSequenceClassification,) if is_onnxruntime_available() else (), "default": "facebook/bart-large-mnli", }, - "text-generation": { - "impl": TextGenerationPipeline, - "class": (ORTModelForCausalLM,) if is_onnxruntime_available() else (), - "default": "distilgpt2", - }, } @@ -80,6 +89,9 @@ def pipeline( if accelerator != "ort": raise ValueError(f"Accelerator {accelerator} is not supported. Supported accelerators are ort") + load_tokenizer = type(model.config) in TOKENIZER_MAPPING or model.config.tokenizer_class is not None + load_feature_extractor = type(model.config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None + if model is None: model_id = SUPPORTED_TASKS[task]["default"] model = SUPPORTED_TASKS[task]["class"][0].from_pretrained(model_id, from_transformers=True) @@ -87,20 +99,27 @@ def pipeline( model_id = model model = SUPPORTED_TASKS[task]["class"][0].from_pretrained(model, from_transformers=True) elif isinstance(model, ORTModel): - if tokenizer is None: + if tokenizer is None and load_tokenizer: raise ValueError("If you pass a model as a ORTModel, you must pass a tokenizer as well") + if feature_extractor is None and load_feature_extractor: + raise ValueError("If you pass a model as a ORTModel, you must pass a feature extractor as well") else: raise ValueError( f"""Model {model} is not supported. Please provide a valid model either as string or ORTModel. 
You can also provide non model then a default one will be used""" ) - if tokenizer is None: - tokenizer = AutoTokenizer.from_pretrained(model_id) + + if tokenizer is None and load_tokenizer: + tokenizer = get_preprocessor(model_id) + if feature_extractor is None and load_feature_extractor: + feature_extractor = get_preprocessor(model_id) return transformers_pipeline( task, model=model, tokenizer=tokenizer, + feature_extractor=feature_extractor, use_fast=use_fast, + use_auth_token=use_auth_token, **kwargs, ) diff --git a/setup.py b/setup.py index 7a45388bb7..54f4eeee1e 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ "huggingface_hub>=0.4.0", ] -TESTS_REQUIRE = ["pytest", "requests", "parameterized", "pytest-xdist"] +TESTS_REQUIRE = ["pytest", "requests", "parameterized", "pytest-xdist", "Pillow"] QUALITY_REQUIRE = ["black~=22.0", "flake8>=3.8.3", "isort>=5.5.4"] diff --git a/tests/onnxruntime/test_modeling_ort.py b/tests/onnxruntime/test_modeling_ort.py index 5a0a100336..4fab52d4c3 100644 --- a/tests/onnxruntime/test_modeling_ort.py +++ b/tests/onnxruntime/test_modeling_ort.py @@ -5,24 +5,28 @@ from pathlib import Path import torch +from PIL import Image from transformers import ( AutoModel, AutoModelForCausalLM, + AutoModelForImageClassification, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoModelForTokenClassification, - AutoTokenizer, PretrainedConfig, pipeline, ) +from transformers.onnx.utils import get_preprocessor from transformers.testing_utils import require_torch_gpu import onnxruntime +import requests from huggingface_hub.utils import EntryNotFoundError from optimum.onnxruntime import ( ONNX_WEIGHTS_NAME, ORTModelForCausalLM, ORTModelForFeatureExtraction, + ORTModelForImageClassification, ORTModelForQuestionAnswering, ORTModelForSequenceClassification, ORTModelForTokenClassification, @@ -135,15 +139,15 @@ def test_supported_transformers_architectures(self, *args, **kwargs): def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: - model = ORTModelForQuestionAnswering.from_pretrained("t5-small") + model = ORTModelForQuestionAnswering.from_pretrained("t5-small", from_transformers=True) - self.assertTrue("Unrecognized configuration class", context.exception) + self.assertIn("Unrecognized configuration class", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) def test_model_call(self, *args, **kwargs): model_arch, model_id = args model = ORTModelForQuestionAnswering.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) tokens = tokenizer( "This is a sample output", return_tensors="pt", @@ -160,7 +164,7 @@ def test_compare_to_transformers(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForQuestionAnswering.from_pretrained(model_id, from_transformers=True) trfs_model = AutoModelForQuestionAnswering.from_pretrained(model_id) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) tokens = tokenizer( "This is a sample output", return_tensors="pt", @@ -177,7 +181,7 @@ def test_compare_to_transformers(self, *args, **kwargs): def test_pipeline(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForQuestionAnswering.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = 
pipeline("question-answering", model=onnx_model, tokenizer=tokenizer) question = "Whats my name?" context = "My Name is Philipp and I live in Nuremberg." @@ -192,7 +196,7 @@ def test_pipeline(self, *args, **kwargs): def test_pipeline_on_gpu(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForQuestionAnswering.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("question-answering", model=onnx_model, tokenizer=tokenizer, device=0) question = "Whats my name?" context = "My Name is Philipp and I live in Nuremberg." @@ -207,7 +211,7 @@ def test_pipeline_on_gpu(self, *args, **kwargs): def test_default_pipeline_and_model_device(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForQuestionAnswering.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("question-answering", model=onnx_model, tokenizer=tokenizer) self.assertEqual(pp.device, pp.model.device) @@ -238,15 +242,15 @@ def test_supported_transformers_architectures(self, *args, **kwargs): def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: - model = ORTModelForSequenceClassification.from_pretrained("t5-small", from_transformers=Tru) + model = ORTModelForSequenceClassification.from_pretrained("t5-small", from_transformers=True) - self.assertTrue("Unrecognized configuration class", context.exception) + self.assertIn("Unrecognized configuration class", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) def test_model_forward_call(self, *args, **kwargs): model_arch, model_id = args model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) tokens = tokenizer( "This is a sample output", return_tensors="pt", @@ -260,7 +264,7 @@ def test_compare_to_transformers(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True) trfs_model = AutoModelForSequenceClassification.from_pretrained(model_id) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) tokens = tokenizer( "This is a sample output", return_tensors="pt", @@ -276,7 +280,7 @@ def test_compare_to_transformers(self, *args, **kwargs): def test_pipeline(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("text-classification", model=onnx_model, tokenizer=tokenizer) text = "My Name is Philipp and i live in Germany." outputs = pp(text) @@ -290,7 +294,7 @@ def test_pipeline(self, *args, **kwargs): def test_pipeline_on_gpu(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("text-classification", model=onnx_model, tokenizer=tokenizer, device=0) text = "My Name is Philipp and i live in Germany." 
outputs = pp(text) @@ -304,7 +308,7 @@ def test_pipeline_on_gpu(self, *args, **kwargs): def test_default_pipeline_and_model_device(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("text-classification", model=onnx_model, tokenizer=tokenizer) self.assertEqual(pp.device, onnx_model.device) @@ -312,7 +316,7 @@ def test_pipeline_zero_shot_classification(self): onnx_model = ORTModelForSequenceClassification.from_pretrained( "typeform/distilbert-base-uncased-mnli", from_transformers=True ) - tokenizer = AutoTokenizer.from_pretrained("typeform/distilbert-base-uncased-mnli") + tokenizer = get_preprocessor("typeform/distilbert-base-uncased-mnli") pp = pipeline("zero-shot-classification", model=onnx_model, tokenizer=tokenizer) sequence_to_classify = "Who are you voting for in 2020?" candidate_labels = ["Europe", "public health", "politics", "elections"] @@ -348,15 +352,15 @@ def test_supported_transformers_architectures(self, *args, **kwargs): def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: - model = ORTModelForTokenClassification.from_pretrained("t5-small", from_transformers=Tru) + model = ORTModelForTokenClassification.from_pretrained("t5-small", from_transformers=True) - self.assertTrue("Unrecognized configuration class", context.exception) + self.assertIn("Unrecognized configuration class", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) def test_model_call(self, *args, **kwargs): model_arch, model_id = args model = ORTModelForTokenClassification.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) tokens = tokenizer( "This is a sample output", return_tensors="pt", @@ -370,7 +374,7 @@ def test_compare_to_transformers(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForTokenClassification.from_pretrained(model_id, from_transformers=True) trfs_model = AutoModelForTokenClassification.from_pretrained(model_id) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) tokens = tokenizer( "This is a sample output", return_tensors="pt", @@ -386,7 +390,7 @@ def test_compare_to_transformers(self, *args, **kwargs): def test_pipeline(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForTokenClassification.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("token-classification", model=onnx_model, tokenizer=tokenizer) text = "My Name is Philipp and i live in Germany." outputs = pp(text) @@ -399,7 +403,7 @@ def test_pipeline(self, *args, **kwargs): def test_pipeline_on_gpu(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForTokenClassification.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("token-classification", model=onnx_model, tokenizer=tokenizer, device=0) text = "My Name is Philipp and i live in Germany." 
outputs = pp(text) @@ -412,7 +416,7 @@ def test_pipeline_on_gpu(self, *args, **kwargs): def test_default_pipeline_and_model_device(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForTokenClassification.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("token-classification", model=onnx_model, tokenizer=tokenizer) self.assertEqual(pp.device, onnx_model.device) @@ -439,17 +443,11 @@ def test_supported_transformers_architectures(self, *args, **kwargs): self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) - def test_load_vanilla_transformers_which_is_not_supported(self): - with self.assertRaises(Exception) as context: - model = ORTModelForFeatureExtraction.from_pretrained("google/vit-base-patch16-224", from_transformers=Tru) - - self.assertTrue("Unrecognized configuration class", context.exception) - @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) def test_model_call(self, *args, **kwargs): model_arch, model_id = args model = ORTModelForFeatureExtraction.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) tokens = tokenizer( "This is a sample output", return_tensors="pt", @@ -463,7 +461,7 @@ def test_compare_to_transformers(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForFeatureExtraction.from_pretrained(model_id, from_transformers=True) trfs_model = AutoModel.from_pretrained(model_id) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) tokens = tokenizer( "This is a sample output", return_tensors="pt", @@ -479,7 +477,7 @@ def test_compare_to_transformers(self, *args, **kwargs): def test_pipeline(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForFeatureExtraction.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("feature-extraction", model=onnx_model, tokenizer=tokenizer) text = "My Name is Philipp and i live in Germany." outputs = pp(text) @@ -492,7 +490,7 @@ def test_pipeline(self, *args, **kwargs): def test_pipeline_on_gpu(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForFeatureExtraction.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("feature-extraction", model=onnx_model, tokenizer=tokenizer, device=0) text = "My Name is Philipp and i live in Germany." 
outputs = pp(text) @@ -505,7 +503,7 @@ def test_pipeline_on_gpu(self, *args, **kwargs): def test_default_pipeline_and_model_device(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForFeatureExtraction.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("feature-extraction", model=onnx_model, tokenizer=tokenizer) self.assertEqual(pp.device, onnx_model.device) @@ -527,13 +525,13 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: model = ORTModelForCausalLM.from_pretrained("google/vit-base-patch16-224", from_transformers=True) - self.assertTrue("Unrecognized configuration class", context.exception) + self.assertIn("Unrecognized configuration class", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) def test_model_call(self, *args, **kwargs): model_arch, model_id = args model = ORTModelForCausalLM.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) tokens = tokenizer( "This is a sample output", return_tensors="pt", @@ -546,7 +544,7 @@ def test_model_call(self, *args, **kwargs): def test_generate_utils(self, *args, **kwargs): model_arch, model_id = args model = ORTModelForCausalLM.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) text = "This is a sample output" tokens = tokenizer( text, @@ -561,7 +559,7 @@ def test_generate_utils(self, *args, **kwargs): def test_generate_utils_with_input_ids(self, *args, **kwargs): model_arch, model_id = args model = ORTModelForCausalLM.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) text = "This is a sample output" tokens = tokenizer( text, @@ -577,7 +575,7 @@ def test_compare_to_transformers(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForCausalLM.from_pretrained(model_id, from_transformers=True) trfs_model = AutoModelForCausalLM.from_pretrained(model_id) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) tokens = tokenizer( "This is a sample output", return_tensors="pt", @@ -593,7 +591,7 @@ def test_compare_to_transformers(self, *args, **kwargs): def test_pipeline(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForCausalLM.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("text-generation", model=onnx_model, tokenizer=tokenizer) text = "My Name is Philipp and i live" outputs = pp(text) @@ -607,7 +605,7 @@ def test_pipeline(self, *args, **kwargs): def test_pipeline_on_gpu(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForCausalLM.from_pretrained(model_id, from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("text-generation", model=onnx_model, tokenizer=tokenizer) text = "My Name is Philipp and i live" outputs = pp(text) @@ -621,6 +619,90 @@ def test_pipeline_on_gpu(self, *args, **kwargs): def test_default_pipeline_and_model_device(self, *args, **kwargs): model_arch, model_id = args onnx_model = ORTModelForCausalLM.from_pretrained(model_id, 
from_transformers=True) - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = get_preprocessor(model_id) pp = pipeline("text-generation", model=onnx_model, tokenizer=tokenizer) self.assertEqual(pp.device, onnx_model.device) + + +class ORTModelForImageClassificationIntegrationTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = { + "vit": "hf-internal-testing/tiny-random-vit", + } + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_supported_transformers_architectures(self, *args, **kwargs): + model_arch, model_id = args + model = ORTModelForImageClassification.from_pretrained(model_id, from_transformers=True) + self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.config, PretrainedConfig) + + def test_load_vanilla_transformers_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + model = ORTModelForImageClassification.from_pretrained("t5-small", from_transformers=True) + + self.assertIn("Unrecognized configuration class", str(context.exception)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_model_forward_call(self, *args, **kwargs): + model_arch, model_id = args + model = ORTModelForImageClassification.from_pretrained(model_id, from_transformers=True) + preprocessor = get_preprocessor(model_id) + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + inputs = preprocessor(images=image, return_tensors="pt") + outputs = model(**inputs) + self.assertTrue("logits" in outputs) + self.assertTrue(isinstance(outputs.logits, torch.Tensor)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_compare_to_transformers(self, *args, **kwargs): + model_arch, model_id = args + onnx_model = ORTModelForImageClassification.from_pretrained(model_id, from_transformers=True) + trfs_model = AutoModelForImageClassification.from_pretrained(model_id) + preprocessor = get_preprocessor(model_id) + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + inputs = preprocessor(images=image, return_tensors="pt") + with torch.no_grad(): + trfs_outputs = trfs_model(**inputs) + onnx_outputs = onnx_model(**inputs) + + # compare tensor outputs + self.assertTrue(torch.allclose(onnx_outputs.logits, trfs_outputs.logits, atol=1e-4)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_pipeline(self, *args, **kwargs): + model_arch, model_id = args + onnx_model = ORTModelForImageClassification.from_pretrained(model_id, from_transformers=True) + preprocessor = get_preprocessor(model_id) + pp = pipeline("image-classification", model=onnx_model, feature_extractor=preprocessor) + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + outputs = pp(url) + + # compare model output class + self.assertGreaterEqual(outputs[0]["score"], 0.0) + self.assertTrue(isinstance(outputs[0]["label"], str)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + @require_torch_gpu + def test_pipeline_on_gpu(self, *args, **kwargs): + model_arch, model_id = args + onnx_model = ORTModelForImageClassification.from_pretrained(model_id, from_transformers=True) + preprocessor = get_preprocessor(model_id) + pp = pipeline("image-classification", model=onnx_model, feature_extractor=preprocessor) + url = 
"http://images.cocodataset.org/val2017/000000039769.jpg" + outputs = pp(url) + # check model device + self.assertEqual(pp.model.device.type.lower(), "cuda") + + # compare model output class + self.assertGreaterEqual(outputs[0]["score"], 0.0) + self.assertTrue(isinstance(outputs[0]["label"], str)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_MODEL_ID.items()) + def test_default_pipeline_and_model_device(self, *args, **kwargs): + model_arch, model_id = args + onnx_model = ORTModelForImageClassification.from_pretrained(model_id, from_transformers=True) + preprocessor = get_preprocessor(model_id) + pp = pipeline("image-classification", model=onnx_model, feature_extractor=preprocessor) + self.assertEqual(pp.device, onnx_model.device) diff --git a/tests/onnxruntime/test_onnxruntime.py b/tests/onnxruntime/test_onnxruntime.py index 4d3c6cc58f..3f20df1897 100644 --- a/tests/onnxruntime/test_onnxruntime.py +++ b/tests/onnxruntime/test_onnxruntime.py @@ -13,17 +13,15 @@ # limitations under the License. import gc -import os import tempfile import unittest from functools import partial from pathlib import Path -from transformers import AutoTokenizer from transformers.onnx import validate_model_outputs from onnxruntime.quantization import QuantFormat, QuantizationMode, QuantType -from optimum.onnxruntime import ORTConfig, ORTOptimizer, ORTQuantizableOperator, ORTQuantizer +from optimum.onnxruntime import ORTConfig, ORTOptimizer, ORTQuantizer from optimum.onnxruntime.configuration import ( AutoCalibrationConfig, AutoQuantizationConfig, @@ -68,7 +66,7 @@ def test_optimize(self): ) validate_model_outputs( optimizer._onnx_config, - optimizer.tokenizer, + optimizer.preprocessor, optimizer.model, optimized_model_path, list(optimizer._onnx_config.outputs.keys()), @@ -128,7 +126,7 @@ def test_dynamic_quantization(self): ) validate_model_outputs( quantizer._onnx_config, - quantizer.tokenizer, + quantizer.preprocessor, quantizer.model, q8_model_path, list(quantizer._onnx_config.outputs.keys()), @@ -163,7 +161,7 @@ def preprocess_function(examples, tokenizer): calibration_dataset = quantizer.get_calibration_dataset( "glue", dataset_config_name="sst2", - preprocess_function=partial(preprocess_function, tokenizer=quantizer.tokenizer), + preprocess_function=partial(preprocess_function, tokenizer=quantizer.preprocessor), num_samples=40, dataset_split="train", ) @@ -181,7 +179,7 @@ def preprocess_function(examples, tokenizer): ) validate_model_outputs( quantizer._onnx_config, - quantizer.tokenizer, + quantizer.preprocessor, quantizer.model, q8_model_path, list(quantizer._onnx_config.outputs.keys()),