
Commit f62cb83

unography and sgugger authored
Adds CLIP to models exportable with ONNX (#18515)
* onnx config for clip
* default opset as 14
* changes from the original repo
* input values order fix
* outputs fix
* remove unused import
* ran make fix-copies
* black format
* review comments: forward ref, import fix, model change revert, .to cleanup
* make style
* formatting fixes
* revert groupvit
* comment for cast to int32
* comment fix
* make .T as .t() for onnx conversion
* ran make fix-copies
* remove unneeded comment

Co-authored-by: Sylvain Gugger <[email protected]>

* fix copies
* remove comment

Co-authored-by: Sylvain Gugger <[email protected]>
1 parent 50949fa commit f62cb83

9 files changed: +82 −10


docs/source/en/serialization.mdx

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ Ready-made configurations include the following architectures:
 - BlenderbotSmall
 - BLOOM
 - CamemBERT
+- CLIP
 - CodeGen
 - ConvBERT
 - ConvNeXT
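
With CLIP added to the list of ready-made configurations, an export can be driven from Python. A minimal sketch, assuming the illustrative checkpoint `openai/clip-vit-base-patch32` and output file `clip.onnx` (neither is fixed by this commit):

```python
from pathlib import Path

from transformers import CLIPModel, CLIPProcessor
from transformers.models.clip import CLIPOnnxConfig
from transformers.onnx import export

# Illustrative checkpoint and output path, not mandated by the diff.
checkpoint = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(checkpoint)
processor = CLIPProcessor.from_pretrained(checkpoint)

# The ONNX config added in this commit is built from the model's config.
onnx_config = CLIPOnnxConfig(model.config)

# export(preprocessor, model, config, opset, output) traces the model and
# writes the ONNX graph; the config's default opset is 14.
onnx_inputs, onnx_outputs = export(
    processor, model, onnx_config, onnx_config.default_onnx_opset, Path("clip.onnx")
)
print(onnx_inputs, onnx_outputs)
```

The same export should also be reachable through the `transformers.onnx` command-line entry point once the feature mapping further down is in place.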

src/transformers/models/clip/__init__.py

Lines changed: 14 additions & 2 deletions
@@ -29,7 +29,13 @@
 
 
 _import_structure = {
-    "configuration_clip": ["CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP", "CLIPConfig", "CLIPTextConfig", "CLIPVisionConfig"],
+    "configuration_clip": [
+        "CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "CLIPConfig",
+        "CLIPOnnxConfig",
+        "CLIPTextConfig",
+        "CLIPVisionConfig",
+    ],
     "tokenization_clip": ["CLIPTokenizer"],
 }

@@ -95,7 +101,13 @@
 
 
 if TYPE_CHECKING:
-    from .configuration_clip import CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig, CLIPTextConfig, CLIPVisionConfig
+    from .configuration_clip import (
+        CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        CLIPConfig,
+        CLIPOnnxConfig,
+        CLIPTextConfig,
+        CLIPVisionConfig,
+    )
     from .tokenization_clip import CLIPTokenizer
 
 try:
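
A quick way to see what the lazy-import change exposes, assuming a transformers build that contains this commit (a sanity sketch, nothing more):

```python
# The new ONNX config is reachable through the CLIP subpackage, next to the
# existing configuration classes, and subclasses the generic OnnxConfig.
from transformers.models.clip import CLIPConfig, CLIPOnnxConfig
from transformers.onnx import OnnxConfig

assert issubclass(CLIPOnnxConfig, OnnxConfig)
print(CLIPOnnxConfig(CLIPConfig()).default_onnx_opset)  # 14
```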

src/transformers/models/clip/configuration_clip.py

Lines changed: 49 additions & 1 deletion
@@ -16,9 +16,16 @@
 
 import copy
 import os
-from typing import Union
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
+
+
+if TYPE_CHECKING:
+    from ...processing_utils import ProcessorMixin
+    from ...utils import TensorType
 
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging
 
 

@@ -317,3 +324,44 @@ def to_dict(self):
         output["vision_config"] = self.vision_config.to_dict()
         output["model_type"] = self.__class__.model_type
         return output
+
+
+class CLIPOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("pixel_values", {0: "batch"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("logits_per_image", {0: "batch"}),
+                ("logits_per_text", {0: "batch"}),
+                ("text_embeds", {0: "batch"}),
+                ("image_embeds", {0: "batch"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-4
+
+    def generate_dummy_inputs(
+        self,
+        processor: "ProcessorMixin",
+        framework: Optional["TensorType"] = None,
+    ) -> Mapping[str, Any]:
+
+        text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
+        image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
+        return {**text_input_dict, **image_input_dict}
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 14
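
The declared axes and the overridden `generate_dummy_inputs` can be exercised directly; a rough sketch, assuming the illustrative `openai/clip-vit-base-patch32` processor:

```python
from transformers import CLIPConfig, CLIPProcessor
from transformers.models.clip import CLIPOnnxConfig
from transformers.utils import TensorType

onnx_config = CLIPOnnxConfig(CLIPConfig())

# Dynamic axes used for tracing: a batch axis everywhere, plus a sequence
# axis on the text inputs.
print(onnx_config.inputs)   # input_ids, pixel_values, attention_mask
print(onnx_config.outputs)  # logits_per_image, logits_per_text, text_embeds, image_embeds
print(onnx_config.default_onnx_opset, onnx_config.atol_for_validation)

# generate_dummy_inputs takes the full processor and merges the tokenizer
# and feature-extractor dummies into a single input dict for the export.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
dummy = onnx_config.generate_dummy_inputs(processor, framework=TensorType.PYTORCH)
print({name: tensor.shape for name, tensor in dummy.items()})
```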

src/transformers/models/clip/modeling_clip.py

Lines changed: 6 additions & 3 deletions
@@ -68,7 +68,7 @@ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
 
 def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
     caption_loss = contrastive_loss(similarity)
-    image_loss = contrastive_loss(similarity.T)
+    image_loss = contrastive_loss(similarity.t())
     return (caption_loss + image_loss) / 2.0
 
 

@@ -660,7 +660,10 @@ def forward(
 
         # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
-        pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
+        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+        pooled_output = last_hidden_state[
+            torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
+        ]
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]

@@ -1050,7 +1053,7 @@ def forward(
         # cosine similarity as logits
         logit_scale = self.logit_scale.exp()
         logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
-        logits_per_image = logits_per_text.T
+        logits_per_image = logits_per_text.t()
 
         loss = None
         if return_loss:
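
Both modeling changes are meant to be numerically neutral; they only make the graph traceable at opset 14. A small sanity sketch (not part of the commit) of why: `.t()` matches `.T` for the 2-D similarity matrix, and casting `input_ids` down to int32 leaves the argmax unchanged for CLIP's vocabulary size.

```python
import torch

similarity = torch.randn(4, 4)
# For 2-D tensors the two transposes are identical.
assert torch.equal(similarity.T, similarity.t())

# CLIP token ids (vocab size 49408) fit comfortably in int32, so the cast
# added for ONNX argmax support cannot change which position is selected.
input_ids = torch.randint(0, 49408, (2, 7))
assert torch.equal(input_ids.argmax(dim=-1), input_ids.to(torch.int).argmax(dim=-1))
```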

src/transformers/models/groupvit/modeling_groupvit.py

Lines changed: 5 additions & 2 deletions
@@ -72,7 +72,7 @@ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
 # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->groupvit
 def groupvit_loss(similarity: torch.Tensor) -> torch.Tensor:
     caption_loss = contrastive_loss(similarity)
-    image_loss = contrastive_loss(similarity.T)
+    image_loss = contrastive_loss(similarity.t())
     return (caption_loss + image_loss) / 2.0
 
 

@@ -1132,7 +1132,10 @@ def forward(
 
         # text_embeds.shape = [batch_size, sequence_length, transformer.width]
         # take features from the eot embedding (eot_token is the highest number in each sequence)
-        pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
+        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+        pooled_output = last_hidden_state[
+            torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
+        ]
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]

src/transformers/models/owlvit/modeling_owlvit.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
 # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->owlvit
 def owlvit_loss(similarity: torch.Tensor) -> torch.Tensor:
     caption_loss = contrastive_loss(similarity)
-    image_loss = contrastive_loss(similarity.T)
+    image_loss = contrastive_loss(similarity.t())
     return (caption_loss + image_loss) / 2.0
 
 

src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py

Lines changed: 1 addition & 1 deletion
@@ -154,7 +154,7 @@ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
 # Copied from transformers.models.clip.modeling_clip.clip_loss
 def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
     caption_loss = contrastive_loss(similarity)
-    image_loss = contrastive_loss(similarity.T)
+    image_loss = contrastive_loss(similarity.t())
     return (caption_loss + image_loss) / 2.0
 
 

src/transformers/onnx/features.py

Lines changed: 4 additions & 0 deletions
@@ -201,6 +201,10 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls="models.camembert.CamembertOnnxConfig",
         ),
+        "clip": supported_features_mapping(
+            "default",
+            onnx_config_cls="models.clip.CLIPOnnxConfig",
+        ),
         "codegen": supported_features_mapping(
             "default",
             "causal-lm",

tests/onnx/test_onnx_v2.py

Lines changed: 1 addition & 0 deletions
@@ -185,6 +185,7 @@ def test_values_override(self):
     ("big-bird", "google/bigbird-roberta-base"),
     ("ibert", "kssteven/ibert-roberta-base"),
     ("camembert", "camembert-base"),
+    ("clip", "openai/clip-vit-base-patch32"),
     ("convbert", "YituTech/conv-bert-base"),
     ("codegen", "Salesforce/codegen-350M-multi"),
     ("deberta", "microsoft/deberta-base"),
