Commit 529719a

feat(Qdrant): introduce the use_sparse_embeddings parameter for the document store so that sparse embeddings are a non-breaking change. Needs more testing
1 parent 79d0d52 commit 529719a
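
The gist of the change, as a minimal sketch (not part of the commit; it assumes the integration's public import path and an in-memory Qdrant instance):

    # Default: use_sparse_embeddings=False keeps the legacy single-vector layout,
    # so collections created before this commit keep working unchanged.
    from haystack import Document
    from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

    store = QdrantDocumentStore(location=":memory:", embedding_dim=4)
    store.write_documents([Document(content="hello", embedding=[0.1, 0.2, 0.3, 0.4])])

    # Opt-in: the collection is created with named dense and sparse vectors,
    # enabling query_by_sparse alongside query_by_embedding.
    hybrid_store = QdrantDocumentStore(
        location=":memory:",
        index="hybrid",
        embedding_dim=4,
        use_sparse_embeddings=True,  # the new parameter introduced here
    )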

File tree: 5 files changed (+188 −72 lines)


integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py (+38 −21)
@@ -21,22 +21,26 @@ def documents_to_batch(
         documents: List[Document],
         *,
         embedding_field: str,
+        use_sparse_embeddings: bool,
         sparse_embedding_field: str,
     ) -> List[rest.PointStruct]:
         points = []
         for document in documents:
             payload = document.to_dict(flatten=False)
-            vector = {}
+            if use_sparse_embeddings:
+                vector = {}
 
-            dense_vector = payload.pop(embedding_field, None)
-            if dense_vector is not None:
-                vector[DENSE_VECTORS_NAME] = dense_vector
+                dense_vector = payload.pop(embedding_field, None)
+                if dense_vector is not None:
+                    vector[DENSE_VECTORS_NAME] = dense_vector
 
-            sparse_vector = payload.pop(sparse_embedding_field, None)
-            if sparse_vector is not None:
-                sparse_vector_instance = rest.SparseVector(**sparse_vector)
-                vector[SPARSE_VECTORS_NAME] = sparse_vector_instance
+                sparse_vector = payload.pop(sparse_embedding_field, None)
+                if sparse_vector is not None:
+                    sparse_vector_instance = rest.SparseVector(**sparse_vector)
+                    vector[SPARSE_VECTORS_NAME] = sparse_vector_instance
 
+            if not use_sparse_embeddings:
+                vector = payload.pop(embedding_field) or {}
             _id = self.convert_id(payload.get("id"))
 
             point = rest.PointStruct(
@@ -61,25 +65,38 @@ def convert_id(self, _id: str) -> str:
 
 
 class QdrantToHaystack:
-    def __init__(self, content_field: str, name_field: str, embedding_field: str, sparse_embedding_field: str):
+    def __init__(
+        self,
+        content_field: str,
+        name_field: str,
+        embedding_field: str,
+        use_sparse_embeddings: bool,  # noqa: FBT001
+        sparse_embedding_field: str,
+    ):
         self.content_field = content_field
         self.name_field = name_field
         self.embedding_field = embedding_field
+        self.use_sparse_embeddings = use_sparse_embeddings
         self.sparse_embedding_field = sparse_embedding_field
 
     def point_to_document(self, point: QdrantPoint) -> Document:
         payload = {**point.payload}
-        if hasattr(point, "vector") and point.vector is not None and DENSE_VECTORS_NAME in point.vector:
-            payload["embedding"] = point.vector[DENSE_VECTORS_NAME]
-        else:
-            payload["embedding"] = None
         payload["score"] = point.score if hasattr(point, "score") else None
-        if hasattr(point, "vector") and point.vector is not None and SPARSE_VECTORS_NAME in point.vector:
-            parse_vector_dict = {
-                "indices": point.vector[SPARSE_VECTORS_NAME].indices,
-                "values": point.vector[SPARSE_VECTORS_NAME].values,
-            }
-            payload["sparse_embedding"] = parse_vector_dict
-        else:
-            payload["sparse_embedding"] = None
+        if not self.use_sparse_embeddings:
+            payload["embedding"] = point.vector if hasattr(point, "vector") else None
+
+        if self.use_sparse_embeddings:
+            if hasattr(point, "vector") and point.vector is not None and DENSE_VECTORS_NAME in point.vector:
+                payload["embedding"] = point.vector[DENSE_VECTORS_NAME]
+            else:
+                payload["embedding"] = None
+
+            if hasattr(point, "vector") and point.vector is not None and SPARSE_VECTORS_NAME in point.vector:
+                parse_vector_dict = {
+                    "indices": point.vector[SPARSE_VECTORS_NAME].indices,
+                    "values": point.vector[SPARSE_VECTORS_NAME].values,
+                }
+                payload["sparse_embedding"] = parse_vector_dict
+            else:
+                payload["sparse_embedding"] = None
         return Document.from_dict(payload)
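
To make the converter branches concrete, here is a rough sketch (not from the commit) of the two point shapes documents_to_batch now produces. The vector names "text-dense" and "text-sparse" are taken from the error message introduced later in this commit and are assumed to be the values of DENSE_VECTORS_NAME and SPARSE_VECTORS_NAME:

    from qdrant_client.http import models as rest

    # use_sparse_embeddings=False: legacy layout, one unnamed dense vector per point.
    legacy_point = rest.PointStruct(
        id="c7c62e8e-02b9-4ec6-9f88-46bd97b628b7",  # convert_id() output for the Document id
        payload={"content": "Lorem ipsum"},
        vector=[1.0, 0.0, 0.0, 0.0],
    )

    # use_sparse_embeddings=True: named dense and sparse vectors in a single point.
    named_point = rest.PointStruct(
        id="c7c62e8e-02b9-4ec6-9f88-46bd97b628b7",
        payload={"content": "Lorem ipsum"},
        vector={
            "text-dense": [1.0, 0.0, 0.0, 0.0],
            "text-sparse": rest.SparseVector(indices=[7, 1024, 367], values=[0.1, 0.98, 0.33]),
        },
    )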

integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py (+112 −43)
@@ -71,6 +71,7 @@ def __init__(
         content_field: str = "content",
         name_field: str = "name",
         embedding_field: str = "embedding",
+        use_sparse_embeddings: bool = False,  # noqa: FBT001, FBT002
         sparse_embedding_field: str = "sparse_embedding",
         similarity: str = "cosine",
         return_embedding: bool = False,  # noqa: FBT001, FBT002
@@ -140,13 +141,16 @@ def __init__(
         self.payload_fields_to_index = payload_fields_to_index
 
         # Make sure the collection is properly set up
-        self._set_up_collection(index, embedding_dim, recreate_index, similarity, on_disk, payload_fields_to_index)
+        self._set_up_collection(
+            index, embedding_dim, recreate_index, similarity, use_sparse_embeddings, on_disk, payload_fields_to_index
+        )
 
         self.embedding_dim = embedding_dim
         self.on_disk = on_disk
         self.content_field = content_field
         self.name_field = name_field
         self.embedding_field = embedding_field
+        self.use_sparse_embeddings = use_sparse_embeddings
         self.sparse_embedding_field = sparse_embedding_field
         self.similarity = similarity
         self.index = index
@@ -155,7 +159,9 @@ def __init__(
         self.duplicate_documents = duplicate_documents
         self.qdrant_filter_converter = QdrantFilterConverter()
         self.haystack_to_qdrant_converter = HaystackToQdrant()
-        self.qdrant_to_haystack = QdrantToHaystack(content_field, name_field, embedding_field, sparse_embedding_field)
+        self.qdrant_to_haystack = QdrantToHaystack(
+            content_field, name_field, embedding_field, use_sparse_embeddings, sparse_embedding_field
+        )
         self.write_batch_size = write_batch_size
         self.scroll_size = scroll_size
 
@@ -196,7 +202,7 @@ def write_documents(
             if not isinstance(doc, Document):
                 msg = f"DocumentStore.write_documents() expects a list of Documents but got an element of {type(doc)}."
                 raise ValueError(msg)
-        self._set_up_collection(self.index, self.embedding_dim, False, self.similarity)
+        self._set_up_collection(self.index, self.embedding_dim, False, self.similarity, self.use_sparse_embeddings)
 
         if len(documents) == 0:
             logger.warning("Calling QdrantDocumentStore.write_documents() with empty list")
@@ -214,6 +220,7 @@ def write_documents(
             batch = self.haystack_to_qdrant_converter.documents_to_batch(
                 document_batch,
                 embedding_field=self.embedding_field,
+                use_sparse_embeddings=self.use_sparse_embeddings,
                 sparse_embedding_field=self.sparse_embedding_field,
             )
 
@@ -309,10 +316,17 @@ def query_by_sparse(
         scale_score: bool = True,  # noqa: FBT001, FBT002
         return_embedding: bool = False,  # noqa: FBT001, FBT002
     ) -> List[Document]:
+
+        if not self.use_sparse_embeddings:
+            message = (
+                "Error: tried to query by sparse vector with a Qdrant "
+                "Document Store initialized with use_sparse_embeddings=False"
+            )
+            raise ValueError(message)
+
         qdrant_filters = self.qdrant_filter_converter.convert(filters)
         query_indices = query_sparse_embedding.indices
         query_values = query_sparse_embedding.values
-
         points = self.client.search(
             collection_name=self.index,
             query_vector=rest.NamedSparseVector(
@@ -326,7 +340,6 @@ def query_by_sparse(
             limit=top_k,
             with_vectors=return_embedding,
         )
-
         results = [self.qdrant_to_haystack.point_to_document(point) for point in points]
         if scale_score:
             for document in results:
@@ -345,17 +358,25 @@ def query_by_embedding(
     ) -> List[Document]:
         qdrant_filters = self.qdrant_filter_converter.convert(filters)
 
-        points = self.client.search(
-            collection_name=self.index,
-            query_vector=rest.NamedVector(
-                name=DENSE_VECTORS_NAME,
-                vector=query_embedding,
-            ),
-            query_filter=qdrant_filters,
-            limit=top_k,
-            with_vectors=return_embedding,
-        )
-
+        if self.use_sparse_embeddings:
+            points = self.client.search(
+                collection_name=self.index,
+                query_vector=rest.NamedVector(
+                    name=DENSE_VECTORS_NAME,
+                    vector=query_embedding,
+                ),
+                query_filter=qdrant_filters,
+                limit=top_k,
+                with_vectors=return_embedding,
+            )
+        if not self.use_sparse_embeddings:
+            points = self.client.search(
+                collection_name=self.index,
+                query_vector=query_embedding,
+                query_filter=qdrant_filters,
+                limit=top_k,
+                with_vectors=return_embedding,
+            )
         results = [self.qdrant_to_haystack.point_to_document(point) for point in points]
         if scale_score:
             for document in results:
@@ -397,6 +418,7 @@ def _set_up_collection(
         embedding_dim: int,
         recreate_collection: bool,  # noqa: FBT001
         similarity: str,
+        use_sparse_embeddings: bool,  # noqa: FBT001
         on_disk: bool = False,  # noqa: FBT001, FBT002
         payload_fields_to_index: Optional[List[dict]] = None,
     ):
@@ -405,7 +427,7 @@ def _set_up_collection(
         if recreate_collection:
             # There is no need to verify the current configuration of that
             # collection. It might be just recreated again.
-            self._recreate_collection(collection_name, distance, embedding_dim, on_disk)
+            self._recreate_collection(collection_name, distance, embedding_dim, on_disk, use_sparse_embeddings)
             # Create Payload index if payload_fields_to_index is provided
             self._create_payload_index(collection_name, payload_fields_to_index)
             return
@@ -421,12 +443,33 @@ def _set_up_collection(
             # Qdrant local raises ValueError if the collection is not found, but
             # with the remote server UnexpectedResponse / RpcError is raised.
             # Until that's unified, we need to catch both.
-            self._recreate_collection(collection_name, distance, embedding_dim, on_disk)
+            self._recreate_collection(collection_name, distance, embedding_dim, on_disk, use_sparse_embeddings)
             # Create Payload index if payload_fields_to_index is provided
             self._create_payload_index(collection_name, payload_fields_to_index)
             return
-        current_distance = collection_info.config.params.vectors[DENSE_VECTORS_NAME].distance
-        current_vector_size = collection_info.config.params.vectors[DENSE_VECTORS_NAME].size
+        if self.use_sparse_embeddings:
+            current_distance = collection_info.config.params.vectors[DENSE_VECTORS_NAME].distance
+            current_vector_size = collection_info.config.params.vectors[DENSE_VECTORS_NAME].size
+        if not self.use_sparse_embeddings:
+            current_distance = collection_info.config.params.vectors.distance
+            current_vector_size = collection_info.config.params.vectors.size
+
+        if self.use_sparse_embeddings and not isinstance(collection_info.config.params.vectors, dict):
+            msg = (
+                f"Collection '{collection_name}' already exists in Qdrant, "
+                f"but it has been originally created without sparse embedding vectors. "
+                f"If you want to use that collection, either set `use_sparse_embeddings=False` "
+                f"or run a migration script "
+                f"to use Named Dense Vectors (`text-dense`) and Named Sparse Vectors (`text-sparse`)."
+            )
+            raise ValueError(msg)
+        if not self.use_sparse_embeddings and isinstance(collection_info.config.params.vectors, dict):
+            msg = (
+                f"Collection '{collection_name}' already exists in Qdrant, "
+                f"but it has been originally created with sparse embedding vectors. "
+                f"If you want to use that collection, please set `use_sparse_embeddings=True`."
+            )
+            raise ValueError(msg)
 
         if current_distance != distance:
             msg = (
@@ -446,33 +489,59 @@ def _set_up_collection(
             )
             raise ValueError(msg)
 
-    def _recreate_collection(self, collection_name: str, distance, embedding_dim: int, on_disk: bool):  # noqa: FBT001
-        self.client.recreate_collection(
-            collection_name=collection_name,
-            vectors_config={
-                DENSE_VECTORS_NAME: rest.VectorParams(
+    def _recreate_collection(
+        self,
+        collection_name: str,
+        distance,
+        embedding_dim: int,
+        on_disk: bool,  # noqa: FBT001
+        use_sparse_embeddings: bool,  # noqa: FBT001
+    ):
+        if use_sparse_embeddings:
+            self.client.recreate_collection(
+                collection_name=collection_name,
+                vectors_config={
+                    DENSE_VECTORS_NAME: rest.VectorParams(
+                        size=embedding_dim,
+                        on_disk=on_disk,
+                        distance=distance,
+                    ),
+                },
+                sparse_vectors_config={
+                    SPARSE_VECTORS_NAME: rest.SparseVectorParams(
+                        index=rest.SparseIndexParams(
+                            on_disk=on_disk,
+                        )
+                    )
+                },
+                shard_number=self.shard_number,
+                replication_factor=self.replication_factor,
+                write_consistency_factor=self.write_consistency_factor,
+                on_disk_payload=self.on_disk_payload,
+                hnsw_config=self.hnsw_config,
+                optimizers_config=self.optimizers_config,
+                wal_config=self.wal_config,
+                quantization_config=self.quantization_config,
+                init_from=self.init_from,
+            )
+        if not use_sparse_embeddings:
+            self.client.recreate_collection(
+                collection_name=collection_name,
+                vectors_config=rest.VectorParams(
                     size=embedding_dim,
                     on_disk=on_disk,
                     distance=distance,
                 ),
-            },
-            sparse_vectors_config={
-                SPARSE_VECTORS_NAME: rest.SparseVectorParams(
-                    index=rest.SparseIndexParams(
-                        on_disk=on_disk,
-                    )
-                )
-            },
-            shard_number=self.shard_number,
-            replication_factor=self.replication_factor,
-            write_consistency_factor=self.write_consistency_factor,
-            on_disk_payload=self.on_disk_payload,
-            hnsw_config=self.hnsw_config,
-            optimizers_config=self.optimizers_config,
-            wal_config=self.wal_config,
-            quantization_config=self.quantization_config,
-            init_from=self.init_from,
-        )
+                shard_number=self.shard_number,
+                replication_factor=self.replication_factor,
+                write_consistency_factor=self.write_consistency_factor,
+                on_disk_payload=self.on_disk_payload,
+                hnsw_config=self.hnsw_config,
+                optimizers_config=self.optimizers_config,
+                wal_config=self.wal_config,
+                quantization_config=self.quantization_config,
+                init_from=self.init_from,
+            )
 
     def _handle_duplicate_documents(
         self,
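
A hedged usage sketch of the two query paths, reusing hybrid_store from the sketch near the top (not part of the commit; the SparseEmbedding import path is an assumption and may differ at this point in the integration's history):

    from haystack.dataclasses import SparseEmbedding

    # Dense queries work in both modes; with use_sparse_embeddings=True they are
    # routed to the named "text-dense" vector through rest.NamedVector.
    docs = hybrid_store.query_by_embedding(query_embedding=[0.1, 0.2, 0.3, 0.4], top_k=5)

    # Sparse queries are only valid after opting in; on a store created with
    # use_sparse_embeddings=False the new guard raises ValueError early instead
    # of failing deep inside the Qdrant client.
    docs = hybrid_store.query_by_sparse(
        query_sparse_embedding=SparseEmbedding(indices=[7, 1024, 367], values=[0.1, 0.98, 0.33]),
        top_k=5,
    )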

integrations/qdrant/tests/test_converters.py (+30 −2)
@@ -15,11 +15,12 @@ def haystack_to_qdrant() -> HaystackToQdrant:
 
 
 @pytest.fixture
-def qdrant_to_haystack() -> QdrantToHaystack:
+def qdrant_to_haystack(request) -> QdrantToHaystack:
     return QdrantToHaystack(
         content_field=CONTENT_FIELD,
         name_field=NAME_FIELD,
         embedding_field=EMBEDDING_FIELD,
+        use_sparse_embeddings=request.param,
         sparse_embedding_field=SPARSE_EMBEDDING_FIELD,
     )
 
@@ -30,7 +31,8 @@ def test_convert_id_is_deterministic(haystack_to_qdrant: HaystackToQdrant):
     assert first_id == second_id
 
 
-def test_point_to_document_reverts_proper_structure_from_record(
+@pytest.mark.parametrize("qdrant_to_haystack", [True], indirect=True)
+def test_point_to_document_reverts_proper_structure_from_record_with_sparse(
     qdrant_to_haystack: QdrantToHaystack,
 ):
     point = rest.Record(
@@ -56,3 +58,29 @@ def test_point_to_document_reverts_proper_structure_from_record(
     assert {"indices": [7, 1024, 367], "values": [0.1, 0.98, 0.33]} == document.sparse_embedding.to_dict()
     assert {"test_field": 1} == document.meta
     assert 0.0 == np.sum(np.array([1.0, 0.0, 0.0, 0.0]) - document.embedding)
+
+
+@pytest.mark.parametrize("qdrant_to_haystack", [False], indirect=True)
+def test_point_to_document_reverts_proper_structure_from_record_without_sparse(
+    qdrant_to_haystack: QdrantToHaystack,
+):
+    point = rest.Record(
+        id="c7c62e8e-02b9-4ec6-9f88-46bd97b628b7",
+        payload={
+            "id": "my-id",
+            "id_hash_keys": ["content"],
+            "content": "Lorem ipsum",
+            "content_type": "text",
+            "meta": {
+                "test_field": 1,
+            },
+        },
+        vector=[1.0, 0.0, 0.0, 0.0],
+    )
+    document = qdrant_to_haystack.point_to_document(point)
+    assert "my-id" == document.id
+    assert "Lorem ipsum" == document.content
+    assert "text" == document.content_type
+    assert document.sparse_embedding is None
+    assert {"test_field": 1} == document.meta
+    assert 0.0 == np.sum(np.array([1.0, 0.0, 0.0, 0.0]) - document.embedding)
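
For readers unfamiliar with the fixture trick used above: indirect=True routes the parametrized value into the fixture as request.param instead of passing it to the test function directly, which is how the qdrant_to_haystack fixture receives its use_sparse_embeddings boolean. A self-contained illustration (generic pytest, not from the commit):

    import pytest


    @pytest.fixture
    def flag(request) -> bool:
        # With indirect parametrization, the value from the parametrize list
        # arrives here as request.param.
        return request.param


    @pytest.mark.parametrize("flag", [True, False], indirect=True)
    def test_flag_is_boolean(flag: bool):
        assert isinstance(flag, bool)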
