 from ...types import Document, DocumentObj, Rerank, RerankTokens
 from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
 from ..utils import is_model_cached
+from .utils import preprocess_sentence

 logger = logging.getLogger(__name__)

@@ -201,7 +202,10 @@ def load(self):
             )
             self._use_fp16 = True

-        if self._model_spec.type == "normal":
+        if (
+            self._model_spec.type == "normal"
+            and "qwen3" not in self._model_spec.model_name.lower()
+        ):
             try:
                 import sentence_transformers
                 from sentence_transformers.cross_encoder import CrossEncoder
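Note for readers: Qwen3 rerankers are decoder-only causal LMs rather than CrossEncoders, so this guard routes them away from the sentence-transformers path. The branch added in the next hunk scores documents from the logits of the literal "yes"/"no" answer tokens; a standalone sketch of that technique follows that hunk.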
@@ -229,6 +233,65 @@ def load(self):
             )
             if self._use_fp16:
                 self._model.model.half()
+        elif "qwen3" in self._model_spec.model_name.lower():
+            # qwen3-reranker
+            # for now we use transformers directly
+            # TODO: support engines for rerank models
+            try:
+                from transformers import AutoModelForCausalLM, AutoTokenizer
+            except ImportError:
+                error_message = "Failed to import module 'transformers'"
+                installation_guide = [
+                    "Please make sure 'transformers' is installed. ",
+                    "You can install it by `pip install transformers`\n",
+                ]
+
+                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
+
+            tokenizer = AutoTokenizer.from_pretrained(
+                self._model_path, padding_side="left"
+            )
+            model = self._model = AutoModelForCausalLM.from_pretrained(
+                self._model_path
+            ).eval()
+            max_length = getattr(self._model_spec, "max_tokens")
+
+            prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
+            suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+            prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
+            suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
+
+            def process_inputs(pairs):
+                inputs = tokenizer(
+                    pairs,
+                    padding=False,
+                    truncation="longest_first",
+                    return_attention_mask=False,
+                    max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
+                )
+                for i, ele in enumerate(inputs["input_ids"]):
+                    inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens
+                inputs = tokenizer.pad(
+                    inputs, padding=True, return_tensors="pt", max_length=max_length
+                )
+                for key in inputs:
+                    inputs[key] = inputs[key].to(model.device)
+                return inputs
+
+            token_false_id = tokenizer.convert_tokens_to_ids("no")
+            token_true_id = tokenizer.convert_tokens_to_ids("yes")
+
+            def compute_logits(inputs, **kwargs):
+                batch_scores = model(**inputs).logits[:, -1, :]
+                true_vector = batch_scores[:, token_true_id]
+                false_vector = batch_scores[:, token_false_id]
+                batch_scores = torch.stack([false_vector, true_vector], dim=1)
+                batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+                scores = batch_scores[:, 1].exp().tolist()
+                return scores
+
+            self.process_inputs = process_inputs
+            self.compute_logits = compute_logits
         else:
             try:
                 if self._model_spec.type == "LLM-based":
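For readers unfamiliar with this scoring scheme, here is a minimal, self-contained sketch of the yes/no logit technique the branch above implements. It is illustrative only: the checkpoint id, query, and document are placeholder assumptions, and it omits the chat-template prefix/suffix wrapping and left-padded batching that `process_inputs` performs.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Placeholder checkpoint id; substitute whatever Qwen3 reranker you serve.
    model_path = "Qwen/Qwen3-Reranker-0.6B"
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(model_path).eval()

    # Invented query/document pair, formatted the same way as in rerank().
    pair = (
        "<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n"
        "<Query>: what is xinference\n"
        "<Document>: Xinference is a library for serving LLMs and embedding models."
    )
    inputs = tokenizer([pair], return_tensors="pt")

    with torch.no_grad():
        # The logits at the final position are the next-token distribution.
        last_logits = model(**inputs).logits[:, -1, :]

    yes_id = tokenizer.convert_tokens_to_ids("yes")
    no_id = tokenizer.convert_tokens_to_ids("no")
    # Softmax restricted to the two answer tokens; P("yes") is the relevance score.
    two_way = torch.stack([last_logits[:, no_id], last_logits[:, yes_id]], dim=1)
    score = torch.nn.functional.log_softmax(two_way, dim=1)[:, 1].exp().item()
    print(f"relevance: {score:.4f}")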
@@ -266,15 +329,17 @@ def rerank(
             raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
         logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)

-        from .utils import preprocess_sentence
-
         pre_query = preprocess_sentence(
             query, kwargs.get("instruction", None), self._model_spec.model_name
         )
         sentence_combinations = [[pre_query, doc] for doc in documents]
         # reset n tokens
         self._model.model.n_tokens = 0
-        if self._model_spec.type == "normal":
+        if (
+            self._model_spec.type == "normal"
+            and "qwen3" not in self._model_spec.model_name.lower()
+        ):
+            logger.debug("Passing processed sentences: %s", sentence_combinations)
             similarity_scores = self._model.predict(
                 sentence_combinations,
                 convert_to_numpy=False,
@@ -283,6 +348,23 @@ def rerank(
             ).cpu()
             if similarity_scores.dtype == torch.bfloat16:
                 similarity_scores = similarity_scores.float()
+        elif "qwen3" in self._model_spec.model_name.lower():
+
+            def format_instruction(instruction, query, doc):
+                if instruction is None:
+                    instruction = "Given a web search query, retrieve relevant passages that answer the query"
+                output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
+                    instruction=instruction, query=query, doc=doc
+                )
+                return output
+
+            pairs = [
+                format_instruction(kwargs.get("instruction", None), query, doc)
+                for doc in documents
+            ]
+            # Tokenize the input texts
+            inputs = self.process_inputs(pairs)
+            similarity_scores = self.compute_logits(inputs)
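To make the prompt shape concrete, here is what one pair looks like when rendered with the `format_instruction` defined above (the query and document are invented for illustration):

    print(format_instruction(None, "what is xinference", "Xinference serves models."))
    # <Instruct>: Given a web search query, retrieve relevant passages that answer the query
    # <Query>: what is xinference
    # <Document>: Xinference serves models.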
         else:
             # Related issue: https://github.com/xorbitsai/inference/issues/1775
             similarity_scores = self._model.compute_score(
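One behavioral note: `compute_logits` returns a plain Python list of probabilities, whereas the CrossEncoder path yields a torch tensor, so downstream consumers of `similarity_scores` presumably need to accept both types.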