
Commit 14289c0

Align device and ORT provider in ORT models (#203)
1 parent 8f50efd commit 14289c0
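This commit keeps an ORTModel's reported `torch.device` in sync with the ONNX Runtime execution provider backing its `InferenceSession`: `load_model` now defaults to `CPUExecutionProvider` (matching the PyTorch/TensorFlow/JAX defaults), `.to(device)` swaps the session provider instead of copying weights, and every `forward` places its output tensors on `self.device`. A minimal usage sketch of the new behaviour; the checkpoint name and the `from_transformers` flag are illustrative assumptions, not part of this diff:

import torch
from optimum.onnxruntime import ORTModelForSequenceClassification

# Loading now defaults to CPUExecutionProvider, so the model reports device 'cpu'.
model = ORTModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english", from_transformers=True
)
print(model.device)  # device(type='cpu')

# .to() swaps the session's execution provider rather than moving any weights.
model = model.to(torch.device("cuda"))
print(model.model.get_providers()[0])  # 'CUDAExecutionProvider'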

File tree

5 files changed (+288, -28 lines)
Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
+name: ONNX Runtime / Test GPU
+
+on:
+  workflow_dispatch:
+  schedule:
+    - cron: 0 7 * * * # every day at 7am
+
+jobs:
+  start-runner:
+    name: Start self-hosted EC2 runner
+    runs-on: ubuntu-latest
+    env:
+      AWS_REGION: us-east-1
+      EC2_AMI_ID: ami-0dc1c26161f869ed1
+      EC2_INSTANCE_TYPE: g4dn.xlarge
+      EC2_SUBNET_ID: subnet-859322b4,subnet-b7533b96,subnet-47cfad21,subnet-a396b2ad,subnet-06576a4b,subnet-df0f6180
+      EC2_SECURITY_GROUP: sg-0bb210cd3ec725a13
+      EC2_IAM_ROLE: optimum-ec2-github-actions-role
+    outputs:
+      label: ${{ steps.start-ec2-runner.outputs.label }}
+      ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
+    steps:
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@v1
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          aws-region: ${{ env.AWS_REGION }}
+      - name: Start EC2 runner
+        id: start-ec2-runner
+        uses: philschmid/philschmid-ec2-github-runner@main
+        with:
+          mode: start
+          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
+          ec2-image-id: ${{ env.EC2_AMI_ID }}
+          ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
+          subnet-id: ${{ env.EC2_SUBNET_ID }}
+          security-group-id: ${{ env.EC2_SECURITY_GROUP }}
+          iam-role-name: ${{ env.EC2_IAM_ROLE }}
+          aws-resource-tags: > # optional, requires additional permissions
+            [
+              {"Key": "Name", "Value": "ec2-optimum-github-runner"},
+              {"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
+            ]
+  do-the-job:
+    name: Setup
+    needs: start-runner # required to start the main job when the runner is ready
+    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
+    env:
+      AWS_REGION: us-east-1
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v2
+      - name: Install dependencies
+        run: |
+          sudo apt -y update && sudo pip install --upgrade pip
+          pip install .[onnxruntime-gpu,tests]
+      - name: Test with unittest
+        working-directory: tests
+        run: |
+          python -m unittest discover -s onnxruntime -p 'test_*.py'
+
+  stop-runner:
+    name: Stop self-hosted EC2 runner
+    needs:
+      - start-runner # required to get output from the start-runner job
+      - do-the-job # required to wait when the main job is done
+    runs-on: ubuntu-latest
+    env:
+      AWS_REGION: us-east-1
+    if: ${{ always() }} # required to stop the runner even if the error happened in the previous jobs
+    steps:
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@v1
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          aws-region: ${{ env.AWS_REGION }}
+      - name: Stop EC2 runner
+        uses: philschmid/philschmid-ec2-github-runner@main
+        with:
+          mode: stop
+          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
+          label: ${{ needs.start-runner.outputs.label }}
+          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
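The "Test with unittest" step above can be reproduced locally; a minimal sketch, assuming `pip install .[onnxruntime-gpu,tests]` has already been run from the repository root and the script is launched from the `tests` directory:

import unittest

# Mirror the CI invocation: python -m unittest discover -s onnxruntime -p 'test_*.py'
suite = unittest.defaultTestLoader.discover(start_dir="onnxruntime", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)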

optimum/onnxruntime/modeling_ort.py

Lines changed: 40 additions & 22 deletions

@@ -30,7 +30,7 @@
 from huggingface_hub import HfApi, hf_hub_download

 from ..modeling_base import OptimizedModel
-from .utils import ONNX_WEIGHTS_NAME, _is_gpu_available
+from .utils import ONNX_WEIGHTS_NAME, get_device_for_provider, get_provider_for_device


 logger = logging.getLogger(__name__)

@@ -85,28 +85,50 @@ def __init__(self, model=None, config=None, **kwargs):
         self.config = config
         self.model_save_dir = kwargs.get("model_save_dir", None)
         self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
+        self._device = get_device_for_provider(self.model.get_providers()[0])

         # registers the ORTModelForXXX classes into the transformers AutoModel classes
         # to avoid warnings when create a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
         AutoConfig.register(self.base_model_prefix, AutoConfig)
         self.auto_model_class.register(AutoConfig, self.__class__)

+    @property
+    def device(self) -> torch.device:
+        """
+        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+        device).
+        """
+        return self._device
+
+    @device.setter
+    def device(self, value):
+        self._device = value
+
+    def to(self, device):
+        """
+        Changes the ONNX Runtime provider according to the device.
+        """
+        self.device = device
+        provider = get_provider_for_device(self.device)
+        self.model.set_providers([provider])
+        return self
+
     def forward(self, *args, **kwargs):
         raise NotImplementedError

     @staticmethod
     def load_model(path: Union[str, Path], provider=None):
         """
-        loads ONNX Inference session with Provider. Default Provider is if CUDAExecutionProvider GPU available else `CPUExecutionProvider`
+        Loads an ONNX Inference session with a given provider. Default provider is `CPUExecutionProvider` to match the default behaviour in PyTorch/TensorFlow/JAX.
+
         Arguments:
             path (`str` or `Path`):
-                Directory from which to load
+                Directory from which to load the model.
             provider(`str`, *optional*):
-                Onnxruntime provider to use for loading the model, defaults to `CUDAExecutionProvider` if GPU is
-                available else `CPUExecutionProvider`
+                ONNX Runtime provider to use for loading the model. Defaults to `CPUExecutionProvider`.
         """
         if provider is None:
-            provider = "CUDAExecutionProvider" if _is_gpu_available() else "CPUExecutionProvider"
+            provider = "CPUExecutionProvider"

         return ort.InferenceSession(path, providers=[provider])

@@ -330,10 +352,9 @@ def forward(
             onnx_inputs["token_type_ids"] = token_type_ids.cpu().detach().numpy()
         # run inference
         outputs = self.model.run(None, onnx_inputs)
+        last_hidden_state = torch.from_numpy(outputs[self.model_outputs["last_hidden_state"]]).to(self.device)
         # converts output to namedtuple for pipelines post-processing
-        return BaseModelOutput(
-            last_hidden_state=torch.from_numpy(outputs[self.model_outputs["last_hidden_state"]]),
-        )
+        return BaseModelOutput(last_hidden_state=last_hidden_state)


 QUESTION_ANSWERING_SAMPLE = r"""

@@ -416,10 +437,12 @@ def forward(
             onnx_inputs["token_type_ids"] = token_type_ids.cpu().detach().numpy()
         # run inference
         outputs = self.model.run(None, onnx_inputs)
+        start_logits = torch.from_numpy(outputs[self.model_outputs["start_logits"]]).to(self.device)
+        end_logits = torch.from_numpy(outputs[self.model_outputs["end_logits"]]).to(self.device)
         # converts output to namedtuple for pipelines post-processing
         return QuestionAnsweringModelOutput(
-            start_logits=torch.from_numpy(outputs[self.model_outputs["start_logits"]]),
-            end_logits=torch.from_numpy(outputs[self.model_outputs["end_logits"]]),
+            start_logits=start_logits,
+            end_logits=end_logits,
         )

@@ -519,9 +542,10 @@ def forward(
             onnx_inputs["token_type_ids"] = token_type_ids.cpu().detach().numpy()
         # run inference
         outputs = self.model.run(None, onnx_inputs)
+        logits = torch.from_numpy(outputs[self.model_outputs["logits"]]).to(self.device)
         # converts output to namedtuple for pipelines post-processing
         return SequenceClassifierOutput(
-            logits=torch.from_numpy(outputs[self.model_outputs["logits"]]),
+            logits=logits,
         )

@@ -604,9 +628,10 @@ def forward(
             onnx_inputs["token_type_ids"] = token_type_ids.cpu().detach().numpy()
         # run inference
         outputs = self.model.run(None, onnx_inputs)
+        logits = torch.from_numpy(outputs[self.model_outputs["logits"]]).to(self.device)
         # converts output to namedtuple for pipelines post-processing
         return TokenClassifierOutput(
-            logits=torch.from_numpy(outputs[self.model_outputs["logits"]]),
+            logits=logits,
         )

@@ -665,14 +690,6 @@ def __init__(self, *args, **kwargs):
         self.main_input_name = "input_ids"
         self.model_outputs = {output_key.name: idx for idx, output_key in enumerate(self.model.get_outputs())}

-    @property
-    def device(self) -> torch.device:
-        """
-        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
-        device).
-        """
-        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
     def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
         """
         Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the generate method.

@@ -703,9 +720,10 @@ def forward(
         }
         # run inference
         outputs = self.model.run(None, onnx_inputs)
+        logits = torch.from_numpy(outputs[self.model_outputs["logits"]]).to(self.device)
         # converts output to namedtuple for pipelines post-processing
         return CausalLMOutputWithCrossAttentions(
-            logits=torch.from_numpy(outputs[self.model_outputs["logits"]]),
+            logits=logits,
         )

 # Adapted from https://github.com/huggingface/transformers/blob/99289c08a1b16a805dd4ee46de029e9fd23cba3d/src/transformers/generation_utils.py#L490
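One consequence of routing every output through `.to(self.device)`: tensors returned by `forward` now live on the device the model reports, so downstream pipelines no longer mix CPU outputs with a CUDA provider. A sketch continuing the illustrative `model` from the example near the top; the tokenizer setup is an assumption, not part of this diff:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
inputs = tokenizer("ONNX Runtime on GPU", return_tensors="pt")

# forward() converts the ORT numpy outputs with torch.from_numpy(...).to(self.device),
# so the logits come back on the model's device rather than always on CPU.
outputs = model(**inputs)
assert outputs.logits.device.type == model.device.type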

optimum/onnxruntime/utils.py

Lines changed: 14 additions & 0 deletions

@@ -141,3 +141,17 @@ def wrap_onnx_config_for_loss(onnx_config: OnnxConfig) -> OnnxConfig:
         return OnnxConfigWithPastAndLoss(onnx_config)
     else:
         return OnnxConfigWithLoss(onnx_config)
+
+
+def get_device_for_provider(provider: str) -> torch.device:
+    """
+    Gets the PyTorch device (CPU/CUDA) associated with an ONNX Runtime provider.
+    """
+    return torch.device("cuda") if provider == "CUDAExecutionProvider" else torch.device("cpu")
+
+
+def get_provider_for_device(device: torch.device) -> str:
+    """
+    Gets the ONNX Runtime provider associated with the PyTorch device (CPU/CUDA).
+    """
+    return "CUDAExecutionProvider" if device.type.lower() == "cuda" else "CPUExecutionProvider"
