Skip to content

Commit cd96829

Browse files
authored
FEAT: support qwen3-reranker (#3627)
1 parent 4c27c67 commit cd96829

File tree

4 files changed

+138
-7
lines changed

4 files changed

+138
-7
lines changed

xinference/model/rerank/core.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ...types import Document, DocumentObj, Rerank, RerankTokens
3232
from ..core import CacheableModelSpec, ModelDescription, VirtualEnvSettings
3333
from ..utils import is_model_cached
34+
from .utils import preprocess_sentence
3435

3536
logger = logging.getLogger(__name__)
3637

@@ -201,7 +202,10 @@ def load(self):
201202
)
202203
self._use_fp16 = True
203204

204-
if self._model_spec.type == "normal":
205+
if (
206+
self._model_spec.type == "normal"
207+
and "qwen3" not in self._model_spec.model_name.lower()
208+
):
205209
try:
206210
import sentence_transformers
207211
from sentence_transformers.cross_encoder import CrossEncoder
@@ -229,6 +233,65 @@ def load(self):
229233
)
230234
if self._use_fp16:
231235
self._model.model.half()
236+
elif "qwen3" in self._model_spec.model_name.lower():
237+
# qwen3-reranker
238+
# now we use transformers
239+
# TODO: support engines for rerank models
240+
try:
241+
from transformers import AutoModelForCausalLM, AutoTokenizer
242+
except ImportError:
243+
error_message = "Failed to import module 'transformers'"
244+
installation_guide = [
245+
"Please make sure 'transformers' is installed. ",
246+
"You can install it by `pip install transformers`\n",
247+
]
248+
249+
raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
250+
251+
tokenizer = AutoTokenizer.from_pretrained(
252+
self._model_path, padding_side="left"
253+
)
254+
model = self._model = AutoModelForCausalLM.from_pretrained(
255+
self._model_path
256+
).eval()
257+
max_length = getattr(self._model_spec, "max_tokens")
258+
259+
prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
260+
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
261+
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
262+
suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
263+
264+
def process_inputs(pairs):
    """Tokenize rerank prompts and wrap each one in the chat template.

    Args:
        pairs: Formatted prompt strings, one per (query, document) pair.

    Returns:
        The padded tokenizer output, with every entry converted to a
        PyTorch tensor and moved to ``model.device``.
    """
    # Truncate the variable part so that, once the fixed prefix and suffix
    # tokens are added, the sequence still fits within ``max_length``.
    inputs = tokenizer(
        pairs,
        padding=False,
        truncation="longest_first",
        return_attention_mask=False,
        max_length=max_length - len(prefix_tokens) - len(suffix_tokens),
    )
    # Surround every tokenized prompt with the pre-encoded template tokens.
    for i, ele in enumerate(inputs["input_ids"]):
        inputs["input_ids"][i] = prefix_tokens + ele + suffix_tokens
    # Padding happens only now, after the template tokens are in place.
    inputs = tokenizer.pad(
        inputs, padding=True, return_tensors="pt", max_length=max_length
    )
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)
    return inputs
280+
281+
token_false_id = tokenizer.convert_tokens_to_ids("no")
282+
token_true_id = tokenizer.convert_tokens_to_ids("yes")
283+
284+
def compute_logits(inputs, **kwargs):
    """Score each input by the model's preference for "yes" over "no".

    Args:
        inputs: Keyword arguments forwarded to ``model(**inputs)``, as
            produced by ``process_inputs``.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        A list with one float per batch row: the softmax probability of the
        "yes" token, normalised over the ("no", "yes") pair only.
    """
    # Inference only: without no_grad the forward pass records an autograd
    # graph and keeps activations alive for every rerank request.
    with torch.no_grad():
        # Only the logits at the final sequence position are used.
        batch_scores = model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, token_true_id]
        false_vector = batch_scores[:, token_false_id]
        # Column 0 is "no", column 1 is "yes".
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
    return scores
292+
293+
self.process_inputs = process_inputs
294+
self.compute_logits = compute_logits
232295
else:
233296
try:
234297
if self._model_spec.type == "LLM-based":
@@ -266,15 +329,17 @@ def rerank(
266329
raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
267330
logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
268331

269-
from .utils import preprocess_sentence
270-
271332
pre_query = preprocess_sentence(
272333
query, kwargs.get("instruction", None), self._model_spec.model_name
273334
)
274335
sentence_combinations = [[pre_query, doc] for doc in documents]
275336
# reset n tokens
276337
self._model.model.n_tokens = 0
277-
if self._model_spec.type == "normal":
338+
if (
339+
self._model_spec.type == "normal"
340+
and "qwen3" not in self._model_spec.model_name.lower()
341+
):
342+
logger.debug("Passing processed sentences: %s", sentence_combinations)
278343
similarity_scores = self._model.predict(
279344
sentence_combinations,
280345
convert_to_numpy=False,
@@ -283,6 +348,23 @@ def rerank(
283348
).cpu()
284349
if similarity_scores.dtype == torch.bfloat16:
285350
similarity_scores = similarity_scores.float()
351+
elif "qwen3" in self._model_spec.model_name.lower():
352+
353+
def format_instruction(instruction, query, doc):
    """Build the prompt text for one (query, document) pair.

    Args:
        instruction: Task instruction; when ``None`` a default web-search
            retrieval instruction is used.
        query: The search query.
        doc: The candidate document text.

    Returns:
        A string with ``<Instruct>``, ``<Query>`` and ``<Document>`` sections,
        one per line.
    """
    if instruction is None:
        instruction = "Given a web search query, retrieve relevant passages that answer the query"
    return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"
360+
361+
pairs = [
362+
format_instruction(kwargs.get("instruction", None), query, doc)
363+
for doc in documents
364+
]
365+
# Tokenize the input texts
366+
inputs = self.process_inputs(pairs)
367+
similarity_scores = self.compute_logits(inputs)
286368
else:
287369
# Related issue: https://github.com/xorbitsai/inference/issues/1775
288370
similarity_scores = self._model.compute_score(

xinference/model/rerank/model_spec.json

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,29 @@
6262
"max_tokens": 1024,
6363
"model_id": "openbmb/MiniCPM-Reranker",
6464
"model_revision": "5d2fd7345b6444c89d4c0fa59c92272888f3f2d0"
65+
},
66+
{
67+
"model_name": "Qwen3-Reranker-0.6B",
68+
"type": "normal",
69+
"language": ["en", "zh"],
70+
"max_tokens": 40960,
71+
"model_id": "Qwen/Qwen3-Reranker-0.6B",
72+
"model_revision": "6e9e69830b95c52b5fd889b7690dda3329508de3"
73+
},
74+
{
75+
"model_name": "Qwen3-Reranker-4B",
76+
"type": "normal",
77+
"language": ["en", "zh"],
78+
"max_tokens": 40960,
79+
"model_id": "Qwen/Qwen3-Reranker-4B",
80+
"model_revision": "f16fc5d5d2b9b1d0db8280929242745d79794ef5"
81+
},
82+
{
83+
"model_name": "Qwen3-Reranker-8B",
84+
"type": "normal",
85+
"language": ["en", "zh"],
86+
"max_tokens": 40960,
87+
"model_id": "Qwen/Qwen3-Reranker-8B",
88+
"model_revision": "5fa94080caafeaa45a15d11f969d7978e087a3db"
6589
}
6690
]

xinference/model/rerank/model_spec_modelscope.json

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,5 +57,29 @@
5757
"max_tokens": 1024,
5858
"model_id": "OpenBMB/MiniCPM-Reranker",
5959
"model_hub": "modelscope"
60+
},
61+
{
62+
"model_name": "Qwen3-Reranker-0.6B",
63+
"type": "normal",
64+
"language": ["en", "zh"],
65+
"max_tokens": 40960,
66+
"model_id": "Qwen/Qwen3-Reranker-0.6B",
67+
"model_hub": "modelscope"
68+
},
69+
{
70+
"model_name": "Qwen3-Reranker-4B",
71+
"type": "normal",
72+
"language": ["en", "zh"],
73+
"max_tokens": 40960,
74+
"model_id": "Qwen/Qwen3-Reranker-4B",
75+
"model_hub": "modelscope"
76+
},
77+
{
78+
"model_name": "Qwen3-Reranker-8B",
79+
"type": "normal",
80+
"language": ["en", "zh"],
81+
"max_tokens": 40960,
82+
"model_id": "Qwen/Qwen3-Reranker-8B",
83+
"model_hub": "modelscope"
6084
}
6185
]

xinference/model/rerank/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any
14+
from typing import TYPE_CHECKING, Any
1515

16-
from .core import RerankModelSpec
16+
if TYPE_CHECKING:
17+
from .core import RerankModelSpec
1718

1819

19-
def get_model_version(rerank_model: RerankModelSpec) -> str:
20+
def get_model_version(rerank_model: "RerankModelSpec") -> str:
    """Return the version identifier of a rerank model spec.

    The identifier is the spec's ``model_name``.
    """
    return rerank_model.model_name
2122

2223

0 commit comments

Comments
 (0)