@@ -83,6 +83,8 @@ def __init__(
         self.model_dir = kwargs.get("model_dir", None)
         if not self.model_id and not self.model_dir:
             self.model_dir = "/mnt/models"
+        self.model_revision = kwargs.get("model_revision", None)
+        self.tokenizer_revision = kwargs.get("tokenizer_revision", None)
         self.do_lower_case = not kwargs.get("disable_lower_case", False)
         self.add_special_tokens = not kwargs.get("disable_special_tokens", False)
         self.max_length = kwargs.get("max_length", None)
@@ -111,8 +113,7 @@ def infer_task_from_model_architecture(model_config: str):
         )

     @staticmethod
-    def infer_vllm_supported_from_model_architecture(model_config_path: str):
-        model_config = AutoConfig.from_pretrained(model_config_path)
+    def infer_vllm_supported_from_model_architecture(model_config: str):
         architecture = model_config.architectures[0]
         model_cls = ModelRegistry.load_model_cls(architecture)
         if model_cls is None:
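After this hunk, the helper takes the already-loaded AutoConfig object rather than a path, so the config is fetched exactly once in load() and honors the pinned revision (note the model_config: str annotation is now stale; the argument is a PretrainedConfig). As a minimal standalone sketch of the revision semantics this change relies on (the model id is illustrative only; revision accepts a branch name, tag, or commit hash, and falls back to the default branch when None):

from transformers import AutoConfig

# Illustrative model id; "revision" may be a branch, tag, or commit hash.
config = AutoConfig.from_pretrained("bert-base-uncased", revision="main")

# The helper above only needs the architectures list from this config:
print(config.architectures[0])  # -> "BertForMaskedLM"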
@@ -121,20 +122,24 @@ def infer_vllm_supported_from_model_architecture(model_config_path: str):

     def load(self) -> bool:
         model_id_or_path = self.model_id
+        revision = self.model_revision
+        tokenizer_revision = self.tokenizer_revision
         if self.model_dir:
             model_id_or_path = pathlib.Path(Storage.download(self.model_dir))
             # TODO Read the mapping file, index to object name
+
+        model_config = AutoConfig.from_pretrained(model_id_or_path, revision=revision)
+
         if self.use_vllm and self.device == torch.device("cuda"):  # vllm needs gpu
-            if self.infer_vllm_supported_from_model_architecture(model_id_or_path):
+            if self.infer_vllm_supported_from_model_architecture(model_config):
+                logger.info("supported model by vLLM")
                 self.vllm_engine_args.tensor_parallel_size = torch.cuda.device_count()
                 self.vllm_engine = AsyncLLMEngine.from_engine_args(
                     self.vllm_engine_args
                 )
                 self.ready = True
                 return self.ready

-        model_config = AutoConfig.from_pretrained(model_id_or_path)
-
         if not self.task:
             self.task = self.infer_task_from_model_architecture(model_config)

@@ -154,16 +159,19 @@ def load(self) -> bool:
             # https://github.com/huggingface/transformers/blob/1248f0925234f97da9eee98da2aa22f7b8dbeda1/src/transformers/generation/utils.py#L1376-L1388
             self.tokenizer = AutoTokenizer.from_pretrained(
                 model_id_or_path,
+                revision=tokenizer_revision,
                 do_lower_case=self.do_lower_case,
                 device_map=self.device_map,
                 padding_side="left",
             )
         else:
             self.tokenizer = AutoTokenizer.from_pretrained(
                 model_id_or_path,
+                revision=tokenizer_revision,
                 do_lower_case=self.do_lower_case,
                 device_map=self.device_map,
             )
+
         if not self.tokenizer.pad_token:
             self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         logger.info(f"successfully loaded tokenizer for task: {self.task}")
@@ -172,27 +180,27 @@ def load(self) -> bool:
         if not self.predictor_host:
             if self.task == MLTask.sequence_classification.value:
                 self.model = AutoModelForSequenceClassification.from_pretrained(
-                    model_id_or_path, device_map=self.device_map
+                    model_id_or_path, revision=revision, device_map=self.device_map
                 )
             elif self.task == MLTask.question_answering.value:
                 self.model = AutoModelForQuestionAnswering.from_pretrained(
-                    model_id_or_path, device_map=self.device_map
+                    model_id_or_path, revision=revision, device_map=self.device_map
                 )
             elif self.task == MLTask.token_classification.value:
                 self.model = AutoModelForTokenClassification.from_pretrained(
-                    model_id_or_path, device_map=self.device_map
+                    model_id_or_path, revision=revision, device_map=self.device_map
                 )
             elif self.task == MLTask.fill_mask.value:
                 self.model = AutoModelForMaskedLM.from_pretrained(
-                    model_id_or_path, device_map=self.device_map
+                    model_id_or_path, revision=revision, device_map=self.device_map
                 )
             elif self.task == MLTask.text_generation.value:
                 self.model = AutoModelForCausalLM.from_pretrained(
-                    model_id_or_path, device_map=self.device_map
+                    model_id_or_path, revision=revision, device_map=self.device_map
                 )
             elif self.task == MLTask.text2text_generation.value:
                 self.model = AutoModelForSeq2SeqLM.from_pretrained(
-                    model_id_or_path, device_map=self.device_map
+                    model_id_or_path, revision=revision, device_map=self.device_map
                 )
             else:
                 raise ValueError(
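
For context, a condensed sketch of the revised load() flow under the simplest assumptions: a Hub model id (no model_dir), the non-vLLM path, and a sequence-classification task. Only the revision kwarg names come from the diff; the model id and default values are illustrative.

from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Illustrative values; None falls back to the default branch.
model_id = "distilbert-base-uncased-finetuned-sst-2-english"
revision = None            # model_revision: branch, tag, or commit hash
tokenizer_revision = None  # tokenizer_revision: pinned independently

# The config is now loaded once, up front, with the pinned revision.
config = AutoConfig.from_pretrained(model_id, revision=revision)

# Tokenizer and weights are then pinned the same way in their branches.
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=tokenizer_revision)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, revision=revision
)

Keeping tokenizer_revision separate from model_revision lets a deployment pin the weights and the tokenizer independently, for example fixing the weights to an immutable commit while tracking tokenizer updates.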