Commit 1802076

[fix] Fix semantic_search_usearch with 'binary' (#2989)
* Fix semantic_search_usearch with 'binary'
* Add b1 support back, but with ubinary
1 parent 72d5649 commit 1802076
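To make the two packed precisions in the commit message concrete, here is a small numpy sketch of how they are typically related: "ubinary" packs the sign bits of the float embeddings into uint8 values, and "binary" is the same packed bits shifted into the signed int8 range. This reflects a common description of sentence_transformers' quantize_embeddings rather than code shown in this commit, so treat the exact offset as an assumption.

import numpy as np

# Toy float embedding; the values are made up for illustration
emb = np.array([[0.3, -0.1, 0.7, -0.4, 0.2, -0.9, 0.5, 0.1]], dtype=np.float32)

# "ubinary": threshold at zero, then pack 8 sign bits into each uint8 value
ubinary = np.packbits(emb > 0, axis=-1)  # -> array([[171]], dtype=uint8)

# "binary" (assumption): the same packed bits shifted by 128 into the int8 range
binary = (ubinary.astype(np.int16) - 128).astype(np.int8)  # -> array([[43]], dtype=int8)

print(ubinary, binary)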

File tree

2 files changed: +13 -7 lines changed


examples/applications/embedding-quantization/semantic_search_usearch.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 from sentence_transformers.quantization import quantize_embeddings, semantic_search_usearch
 
 # 1. Load the quora corpus with questions
-dataset = load_dataset("quora", split="train").map(
+dataset = load_dataset("quora", split="train", trust_remote_code=True).map(
     lambda batch: {"text": [text for sample in batch["questions"] for text in sample["text"]]},
     batched=True,
     remove_columns=["questions", "is_duplicate"],
@@ -26,7 +26,7 @@
 # 4. Choose a target precision for the corpus embeddings
 corpus_precision = "binary"
 # Valid options are: "float32", "uint8", "int8", "ubinary", and "binary"
-# But usearch only supports "float32", "int8", and "binary"
+# But usearch only supports "float32", "int8", "binary" and "ubinary"
 
 # 5. Encode the corpus
 full_corpus_embeddings = model.encode(corpus, normalize_embeddings=True, show_progress_bar=True)
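Taken together with the change above, a minimal sketch of running this example with the newly allowed "ubinary" corpus precision might look as follows. The model name, corpus slice, query, and the (results, search_time) return shape are illustrative assumptions; the quantize_embeddings and semantic_search_usearch calls with corpus_precision="ubinary" are what this commit enables.

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings, semantic_search_usearch

# Load a small slice of the quora questions corpus (trust_remote_code as in the diff above)
dataset = load_dataset("quora", split="train[:1000]", trust_remote_code=True).map(
    lambda batch: {"text": [text for sample in batch["questions"] for text in sample["text"]]},
    batched=True,
    remove_columns=["questions", "is_duplicate"],
)
corpus = dataset["text"]

# Model choice is an assumption; any SentenceTransformer model works the same way
model = SentenceTransformer("all-MiniLM-L6-v2")
full_corpus_embeddings = model.encode(corpus, normalize_embeddings=True, show_progress_bar=True)

# Quantize the corpus to "ubinary" (bits packed into uint8), now accepted by usearch
corpus_embeddings = quantize_embeddings(full_corpus_embeddings, precision="ubinary")

# Queries stay in float32; the search function matches them against the quantized corpus
query_embeddings = model.encode(["How do I learn Python quickly?"], normalize_embeddings=True)

results, search_time = semantic_search_usearch(  # return shape assumed: (hits per query, search time)
    query_embeddings,
    corpus_embeddings=corpus_embeddings,
    corpus_precision="ubinary",
    top_k=5,
)
print(results)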

sentence_transformers/quantization.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ def semantic_search_usearch(
216216
`corpus_embeddings` or `corpus_index` should be used, not
217217
both.
218218
corpus_precision: Precision of the corpus embeddings. The
219-
options are "float32", "int8", or "binary". Default is
220-
"float32".
219+
options are "float32", "int8", "ubinary" or "binary". Default
220+
is "float32".
221221
top_k: Number of top results to retrieve. Default is 10.
222222
ranges: Ranges for quantization of embeddings. This is only used
223223
for int8 quantization, where the ranges refers to the
@@ -263,8 +263,8 @@ def semantic_search_usearch(
263263
raise ValueError("Only corpus_embeddings or corpus_index should be used, not both.")
264264
if corpus_embeddings is None and corpus_index is None:
265265
raise ValueError("Either corpus_embeddings or corpus_index should be used.")
266-
if corpus_precision not in ["float32", "int8", "binary"]:
267-
raise ValueError('corpus_precision must be "float32", "int8", or "binary" for usearch')
266+
if corpus_precision not in ["float32", "int8", "ubinary", "binary"]:
267+
raise ValueError('corpus_precision must be "float32", "int8", "ubinary", "binary" for usearch')
268268

269269
# If corpus_index is not provided, create a new index
270270
if corpus_index is None:
@@ -284,6 +284,12 @@ def semantic_search_usearch(
284284
corpus_index = Index(
285285
ndim=corpus_embeddings.shape[1],
286286
metric="hamming",
287+
dtype="i8",
288+
)
289+
elif corpus_precision == "ubinary":
290+
corpus_index = Index(
291+
ndim=corpus_embeddings.shape[1] * 8,
292+
metric="hamming",
287293
dtype="b1",
288294
)
289295
corpus_index.add(np.arange(len(corpus_embeddings)), corpus_embeddings)
@@ -331,7 +337,7 @@ def semantic_search_usearch(
331337
if rescore_embeddings is not None:
332338
top_k_embeddings = np.array([corpus_index.get(query_indices) for query_indices in indices])
333339
# If the corpus precision is binary, we need to unpack the bits
334-
if corpus_precision == "binary":
340+
if corpus_precision in ("ubinary", "binary"):
335341
top_k_embeddings = np.unpackbits(top_k_embeddings.astype(np.uint8), axis=-1)
336342
top_k_embeddings = top_k_embeddings.astype(int)
337343

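A standalone sketch of why "ubinary" needed its own branch in the index construction above: usearch's b1 dtype counts dimensions in bits, so the packed uint8 width has to be multiplied by 8, whereas the "binary" path keeps the packed width and indexes the bytes as signed i8. The array shapes and random data below are made up purely for illustration; the Index arguments mirror the diff.

import numpy as np
from usearch.index import Index

# Pretend corpus: 1000 embeddings of 1024 binary dimensions, packed to 128 bytes each
num_docs, packed_dim = 1000, 128

# "binary" corpus embeddings are signed int8; per the diff they are indexed with
# dtype="i8", the hamming metric, and ndim equal to the packed byte width
binary_embeddings = np.random.randint(-128, 128, size=(num_docs, packed_dim), dtype=np.int8)
binary_index = Index(ndim=packed_dim, metric="hamming", dtype="i8")
binary_index.add(np.arange(num_docs), binary_embeddings)

# "ubinary" corpus embeddings are packed uint8 bits, stored as native b1 bit-vectors,
# so ndim is the number of bits: packed byte width * 8
ubinary_embeddings = np.random.randint(0, 256, size=(num_docs, packed_dim), dtype=np.uint8)
ubinary_index = Index(ndim=packed_dim * 8, metric="hamming", dtype="b1")
ubinary_index.add(np.arange(num_docs), ubinary_embeddings)

print(len(binary_index), len(ubinary_index))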