Commit 6f8f037

lizzzcai and yuzisun authored

Support model revision and tokenizer revision in huggingface server (kserve#3558)

* support model revision and tokenizer revision
* point to specified commit in test case
* format code

Signed-off-by: Lize Cai <[email protected]>
Signed-off-by: Dan Sun <[email protected]>
Co-authored-by: Dan Sun <[email protected]>

1 parent f678243 commit 6f8f037

3 files changed: +57 -11 lines changed

python/huggingfaceserver/huggingfaceserver/__main__.py (+10)

@@ -33,6 +33,15 @@ def list_of_strings(arg):
     "--model_dir", required=False, default=None, help="A local path to the model binary"
 )
 parser.add_argument("--model_id", required=False, help="Huggingface model id")
+parser.add_argument(
+    "--model_revision", required=False, default=None, help="Huggingface model revision"
+)
+parser.add_argument(
+    "--tokenizer_revision",
+    required=False,
+    default=None,
+    help="Huggingface tokenizer revision",
+)
 parser.add_argument(
     "--max_length", type=int, default=None, help="max sequence length for the tokenizer"
 )
@@ -74,6 +83,7 @@ def list_of_strings(arg):
     engine_args = None
     if _vllm and not args.disable_vllm:
         args.model = args.model_dir or args.model_id
+        args.revision = args.model_revision
         engine_args = AsyncEngineArgs.from_cli_args(args)
     predictor_config = PredictorConfig(
         args.predictor_host,
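Taken together, the new flags let a deployment pin both the model and the tokenizer to a specific Hugging Face Hub revision (branch, tag, or commit hash). A minimal standalone sketch of how the parsed flags behave, assuming only what the diff above shows (the parser below is an illustrative stand-in, not the server's actual parser, and the example values are hypothetical):

import argparse

# Stand-in parser mirroring the two new flags from the diff above.
parser = argparse.ArgumentParser("huggingfaceserver")
parser.add_argument("--model_id", required=False, help="Huggingface model id")
parser.add_argument(
    "--model_revision", required=False, default=None, help="Huggingface model revision"
)
parser.add_argument(
    "--tokenizer_revision",
    required=False,
    default=None,
    help="Huggingface tokenizer revision",
)

# A revision may be a branch name, tag, or commit hash.
args = parser.parse_args(["--model_id", "bert-base-uncased", "--model_revision", "main"])

# vLLM's AsyncEngineArgs.from_cli_args() reads the revision from an
# `args.revision` attribute, hence the alias the diff adds before calling it.
args.revision = args.model_revision
print(args.model_revision, args.tokenizer_revision)  # -> main None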

python/huggingfaceserver/huggingfaceserver/model.py (+19 -11)

@@ -83,6 +83,8 @@ def __init__(
         self.model_dir = kwargs.get("model_dir", None)
         if not self.model_id and not self.model_dir:
             self.model_dir = "/mnt/models"
+        self.model_revision = kwargs.get("model_revision", None)
+        self.tokenizer_revision = kwargs.get("tokenizer_revision", None)
         self.do_lower_case = not kwargs.get("disable_lower_case", False)
         self.add_special_tokens = not kwargs.get("disable_special_tokens", False)
         self.max_length = kwargs.get("max_length", None)
@@ -111,8 +113,7 @@ def infer_task_from_model_architecture(model_config: str):
         )
 
     @staticmethod
-    def infer_vllm_supported_from_model_architecture(model_config_path: str):
-        model_config = AutoConfig.from_pretrained(model_config_path)
+    def infer_vllm_supported_from_model_architecture(model_config: str):
         architecture = model_config.architectures[0]
         model_cls = ModelRegistry.load_model_cls(architecture)
         if model_cls is None:
@@ -121,20 +122,24 @@ def infer_vllm_supported_from_model_architecture(model_config_path: str):
 
     def load(self) -> bool:
         model_id_or_path = self.model_id
+        revision = self.model_revision
+        tokenizer_revision = self.tokenizer_revision
         if self.model_dir:
             model_id_or_path = pathlib.Path(Storage.download(self.model_dir))
             # TODO Read the mapping file, index to object name
+
+        model_config = AutoConfig.from_pretrained(model_id_or_path, revision=revision)
+
         if self.use_vllm and self.device == torch.device("cuda"):  # vllm needs gpu
-            if self.infer_vllm_supported_from_model_architecture(model_id_or_path):
+            if self.infer_vllm_supported_from_model_architecture(model_config):
+                logger.info("supported model by vLLM")
                 self.vllm_engine_args.tensor_parallel_size = torch.cuda.device_count()
                 self.vllm_engine = AsyncLLMEngine.from_engine_args(
                     self.vllm_engine_args
                 )
                 self.ready = True
                 return self.ready
 
-        model_config = AutoConfig.from_pretrained(model_id_or_path)
-
         if not self.task:
             self.task = self.infer_task_from_model_architecture(model_config)
 
@@ -154,16 +159,19 @@ def load(self) -> bool:
             # https://github.com/huggingface/transformers/blob/1248f0925234f97da9eee98da2aa22f7b8dbeda1/src/transformers/generation/utils.py#L1376-L1388
             self.tokenizer = AutoTokenizer.from_pretrained(
                 model_id_or_path,
+                revision=tokenizer_revision,
                 do_lower_case=self.do_lower_case,
                 device_map=self.device_map,
                 padding_side="left",
             )
         else:
             self.tokenizer = AutoTokenizer.from_pretrained(
                 model_id_or_path,
+                revision=tokenizer_revision,
                 do_lower_case=self.do_lower_case,
                 device_map=self.device_map,
             )
+
         if not self.tokenizer.pad_token:
             self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         logger.info(f"successfully loaded tokenizer for task: {self.task}")
@@ -172,27 +180,27 @@ def load(self) -> bool:
         if not self.predictor_host:
             if self.task == MLTask.sequence_classification.value:
                 self.model = AutoModelForSequenceClassification.from_pretrained(
-                    model_id_or_path, device_map=self.device_map
+                    model_id_or_path, revision=revision, device_map=self.device_map
                 )
             elif self.task == MLTask.question_answering.value:
                 self.model = AutoModelForQuestionAnswering.from_pretrained(
-                    model_id_or_path, device_map=self.device_map
+                    model_id_or_path, revision=revision, device_map=self.device_map
                 )
             elif self.task == MLTask.token_classification.value:
                 self.model = AutoModelForTokenClassification.from_pretrained(
-                    model_id_or_path, device_map=self.device_map
+                    model_id_or_path, revision=revision, device_map=self.device_map
                 )
             elif self.task == MLTask.fill_mask.value:
                 self.model = AutoModelForMaskedLM.from_pretrained(
-                    model_id_or_path, device_map=self.device_map
+                    model_id_or_path, revision=revision, device_map=self.device_map
                 )
             elif self.task == MLTask.text_generation.value:
                 self.model = AutoModelForCausalLM.from_pretrained(
-                    model_id_or_path, device_map=self.device_map
+                    model_id_or_path, revision=revision, device_map=self.device_map
                )
             elif self.task == MLTask.text2text_generation.value:
                 self.model = AutoModelForSeq2SeqLM.from_pretrained(
-                    model_id_or_path, device_map=self.device_map
+                    model_id_or_path, revision=revision, device_map=self.device_map
                 )
             else:
                 raise ValueError(
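The `revision` keyword threaded through `load()` is the standard `transformers` mechanism for pinning a Hub download to a branch, tag, or commit hash, and it applies independently to the config, the tokenizer, and the model weights. A minimal standalone sketch of the same pattern (the model id and commit hash are borrowed from the test case below; everything else is stock `transformers` API):

from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer

# Pin every artifact to the same commit; a branch name or tag works as well.
commit = "86b5e0934494bd15c9632b12f734a8a67f723594"

config = AutoConfig.from_pretrained("bert-base-uncased", revision=commit)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", revision=commit)
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased", revision=commit)
print(config.architectures)  # -> ['BertForMaskedLM']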

python/huggingfaceserver/huggingfaceserver/test_model.py (+28)

@@ -54,6 +54,34 @@ def test_bert():
     assert response == {"predictions": ["paris", "france"]}
 
 
+def test_model_revision():
+    # https://huggingface.co/google-bert/bert-base-uncased
+    commit = "86b5e0934494bd15c9632b12f734a8a67f723594"
+    model = HuggingfaceModel(
+        "bert-base-uncased",
+        {
+            "model_id": "bert-base-uncased",
+            "model_revision": commit,
+            "tokenizer_revision": commit,
+            "disable_lower_case": False,
+        },
+    )
+    model.load()
+
+    response = asyncio.run(
+        model(
+            {
+                "instances": [
+                    "The capital of France is [MASK].",
+                    "The capital of [MASK] is paris.",
+                ]
+            },
+            headers={},
+        )
+    )
+    assert response == {"predictions": ["paris", "france"]}
+
+
 def test_bert_predictor_host(httpx_mock: HTTPXMock):
     httpx_mock.add_response(
         json={
