30 | 30 | from huggingface_hub import HfApi, hf_hub_download
31 | 31 |
32 | 32 | from ..modeling_base import OptimizedModel
33 |    | - from .utils import ONNX_WEIGHTS_NAME, _is_gpu_available
   | 33 | + from .utils import ONNX_WEIGHTS_NAME, get_device_for_provider, get_provider_for_device
34 | 34 |
35 | 35 |
36 | 36 | logger = logging.getLogger(__name__)
@@ -85,28 +85,50 @@ def __init__(self, model=None, config=None, **kwargs):
85 | 85 | self.config = config
86 | 86 | self.model_save_dir = kwargs.get("model_save_dir", None)
87 | 87 | self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
   | 88 | + self._device = get_device_for_provider(self.model.get_providers()[0])
88 | 89 |
89 | 90 | # registers the ORTModelForXXX classes into the transformers AutoModel classes
90 | 91 | # to avoid warnings when create a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
91 | 92 | AutoConfig.register(self.base_model_prefix, AutoConfig)
92 | 93 | self.auto_model_class.register(AutoConfig, self.__class__)
93 | 94 |
   | 95 | + @property
   | 96 | + def device(self) -> torch.device:
   | 97 | + """
   | 98 | + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
   | 99 | + device).
   | 100 | + """
   | 101 | + return self._device
   | 102 | +
   | 103 | + @device.setter
   | 104 | + def device(self, value):
   | 105 | + self._device = value
   | 106 | +
   | 107 | + def to(self, device):
   | 108 | + """
   | 109 | + Changes the ONNX Runtime provider according to the device.
   | 110 | + """
   | 111 | + self.device = device
   | 112 | + provider = get_provider_for_device(self.device)
   | 113 | + self.model.set_providers([provider])
   | 114 | + return self
   | 115 | +
94 | 116 | def forward(self, *args, **kwargs):
95 | 117 | raise NotImplementedError
96 | 118 |
97 | 119 | @staticmethod
98 | 120 | def load_model(path: Union[str, Path], provider=None):
99 | 121 | """
100 |     | - loads ONNX Inference session with Provider. Default Provider is if CUDAExecutionProvider GPU available else `CPUExecutionProvider`
    | 122 | + Loads an ONNX Inference session with a given provider. Default provider is `CPUExecutionProvider` to match the default behaviour in PyTorch/TensorFlow/JAX.
    | 123 | +
101 | 124 | Arguments:
102 | 125 | path (`str` or `Path`):
103 |     | - Directory from which to load
    | 126 | + Directory from which to load the model.
104 | 127 | provider(`str`, *optional*):
105 |     | - Onnxruntime provider to use for loading the model, defaults to `CUDAExecutionProvider` if GPU is
106 |     | - available else `CPUExecutionProvider`
    | 128 | + ONNX Runtime provider to use for loading the model. Defaults to `CPUExecutionProvider`.
107 | 129 | """
108 | 130 | if provider is None:
109 |     | - provider = "CUDAExecutionProvider" if _is_gpu_available() else "CPUExecutionProvider"
    | 131 | + provider = "CPUExecutionProvider"
110 | 132 |
111 | 133 | return ort.InferenceSession(path, providers=[provider])
112 | 134 |
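The intent of the new device handling can be illustrated with a small usage sketch. This is only an illustration built on assumptions not shown in this diff: the checkpoint name, the `from_transformers=True` export flag, and passing a `torch.device` to `to()` are assumptions for the example, and a CUDA-enabled onnxruntime build is assumed for the GPU part.

```python
import torch
from optimum.onnxruntime import ORTModelForSequenceClassification

# Sessions now start on CPUExecutionProvider by default, so the model reports a CPU device.
model = ORTModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english",  # illustrative checkpoint
    from_transformers=True,                             # assumed export flag
)
print(model.device)                    # cpu
print(model.model.get_providers()[0])  # CPUExecutionProvider

# to() swaps the session provider according to the target device.
model.to(torch.device("cuda"))         # assumes a CUDA-enabled onnxruntime build
print(model.model.get_providers()[0])  # CUDAExecutionProvider

# load_model can still be given an explicit provider when building a session directly.
session = ORTModelForSequenceClassification.load_model(
    "model.onnx", provider="CUDAExecutionProvider"
)
```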
@@ -330,10 +352,9 @@ def forward(
330 | 352 | onnx_inputs["token_type_ids"] = token_type_ids.cpu().detach().numpy()
331 | 353 | # run inference
332 | 354 | outputs = self.model.run(None, onnx_inputs)
    | 355 | + last_hidden_state = torch.from_numpy(outputs[self.model_outputs["last_hidden_state"]]).to(self.device)
333 | 356 | # converts output to namedtuple for pipelines post-processing
334 |     | - return BaseModelOutput(
335 |     | - last_hidden_state=torch.from_numpy(outputs[self.model_outputs["last_hidden_state"]]),
336 |     | - )
    | 357 | + return BaseModelOutput(last_hidden_state=last_hidden_state)
337 | 358 |
338 | 359 |
339 | 360 | QUESTION_ANSWERING_SAMPLE = r"""
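The same conversion pattern (the numpy output from ONNX Runtime wrapped back into a `torch` tensor placed on `self.device`) is repeated in each task class below. From the caller's side it behaves roughly like the following sketch; the checkpoint name and the `from_transformers=True` export flag are again illustrative assumptions.

```python
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForFeatureExtraction

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = ORTModelForFeatureExtraction.from_pretrained(
    "distilbert-base-uncased", from_transformers=True  # assumed export flag
)

inputs = tokenizer("ONNX Runtime with device placement", return_tensors="pt")
outputs = model(**inputs)

# last_hidden_state is a torch.Tensor on model.device rather than a raw numpy array,
# so downstream transformers pipelines can post-process it as usual.
print(type(outputs.last_hidden_state), outputs.last_hidden_state.device)
```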
@@ -416,10 +437,12 @@ def forward(
416 | 437 | onnx_inputs["token_type_ids"] = token_type_ids.cpu().detach().numpy()
417 | 438 | # run inference
418 | 439 | outputs = self.model.run(None, onnx_inputs)
    | 440 | + start_logits = torch.from_numpy(outputs[self.model_outputs["start_logits"]]).to(self.device)
    | 441 | + end_logits = torch.from_numpy(outputs[self.model_outputs["end_logits"]]).to(self.device)
419 | 442 | # converts output to namedtuple for pipelines post-processing
420 | 443 | return QuestionAnsweringModelOutput(
421 |     | - start_logits=torch.from_numpy(outputs[self.model_outputs["start_logits"]]),
422 |     | - end_logits=torch.from_numpy(outputs[self.model_outputs["end_logits"]]),
    | 444 | + start_logits=start_logits,
    | 445 | + end_logits=end_logits,
423 | 446 | )
424 | 447 |
425 | 448 |
@@ -519,9 +542,10 @@ def forward(
519 | 542 | onnx_inputs["token_type_ids"] = token_type_ids.cpu().detach().numpy()
520 | 543 | # run inference
521 | 544 | outputs = self.model.run(None, onnx_inputs)
    | 545 | + logits = torch.from_numpy(outputs[self.model_outputs["logits"]]).to(self.device)
522 | 546 | # converts output to namedtuple for pipelines post-processing
523 | 547 | return SequenceClassifierOutput(
524 |     | - logits=torch.from_numpy(outputs[self.model_outputs["logits"]]),
    | 548 | + logits=logits,
525 | 549 | )
526 | 550 |
527 | 551 |
@@ -604,9 +628,10 @@ def forward(
604 | 628 | onnx_inputs["token_type_ids"] = token_type_ids.cpu().detach().numpy()
605 | 629 | # run inference
606 | 630 | outputs = self.model.run(None, onnx_inputs)
    | 631 | + logits = torch.from_numpy(outputs[self.model_outputs["logits"]]).to(self.device)
607 | 632 | # converts output to namedtuple for pipelines post-processing
608 | 633 | return TokenClassifierOutput(
609 |     | - logits=torch.from_numpy(outputs[self.model_outputs["logits"]]),
    | 634 | + logits=logits,
610 | 635 | )
611 | 636 |
612 | 637 |
@@ -665,14 +690,6 @@ def __init__(self, *args, **kwargs):
665 | 690 | self.main_input_name = "input_ids"
666 | 691 | self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())}
667 | 692 |
668 |     | - @property
669 |     | - def device(self) -> torch.device:
670 |     | - """
671 |     | - `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
672 |     | - device).
673 |     | - """
674 |     | - return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
675 |     | -
676 | 693 | def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
677 | 694 | """
678 | 695 | Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method.
@@ -703,9 +720,10 @@ def forward(
703 | 720 | }
704 | 721 | # run inference
705 | 722 | outputs = self.model.run(None, onnx_inputs)
    | 723 | + logits = torch.from_numpy(outputs[self.model_outputs["logits"]]).to(self.device)
706 | 724 | # converts output to namedtuple for pipelines post-processing
707 | 725 | return CausalLMOutputWithCrossAttentions(
708 |     | - logits=torch.from_numpy(outputs[self.model_outputs["logits"]]),
    | 726 | + logits=logits,
709 | 727 | )
710 | 728 |
711 | 729 | # Adapted from https://github.com/huggingface/transformers/blob/99289c08a1b16a805dd4ee46de029e9fd23cba3d/src/transformers/generation_utils.py#L490
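Dropping the hard-coded `device` property from the causal-LM class means generation also follows the session's provider via the shared base-class device. A rough usage sketch under the same assumptions as above (illustrative checkpoint, assumed `from_transformers=True` export flag, and `generate()` support inherited from `transformers`):

```python
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = ORTModelForCausalLM.from_pretrained("gpt2", from_transformers=True)  # assumed export flag

inputs = tokenizer("ONNX Runtime is", return_tensors="pt")
# Logits come back as torch tensors on model.device, so generate() works unchanged.
generated = model.generate(**inputs, max_length=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```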