@@ -71,6 +71,7 @@ def __init__(
71
71
content_field : str = "content" ,
72
72
name_field : str = "name" ,
73
73
embedding_field : str = "embedding" ,
74
+ use_sparse_embeddings : bool = False , # noqa: FBT001, FBT002
74
75
sparse_embedding_field : str = "sparse_embedding" ,
75
76
similarity : str = "cosine" ,
76
77
return_embedding : bool = False , # noqa: FBT001, FBT002
@@ -140,13 +141,16 @@ def __init__(
140
141
self .payload_fields_to_index = payload_fields_to_index
141
142
142
143
# Make sure the collection is properly set up
143
- self ._set_up_collection (index , embedding_dim , recreate_index , similarity , on_disk , payload_fields_to_index )
144
+ self ._set_up_collection (
145
+ index , embedding_dim , recreate_index , similarity , use_sparse_embeddings , on_disk , payload_fields_to_index
146
+ )
144
147
145
148
self .embedding_dim = embedding_dim
146
149
self .on_disk = on_disk
147
150
self .content_field = content_field
148
151
self .name_field = name_field
149
152
self .embedding_field = embedding_field
153
+ self .use_sparse_embeddings = use_sparse_embeddings
150
154
self .sparse_embedding_field = sparse_embedding_field
151
155
self .similarity = similarity
152
156
self .index = index
@@ -155,7 +159,9 @@ def __init__(
155
159
self .duplicate_documents = duplicate_documents
156
160
self .qdrant_filter_converter = QdrantFilterConverter ()
157
161
self .haystack_to_qdrant_converter = HaystackToQdrant ()
158
- self .qdrant_to_haystack = QdrantToHaystack (content_field , name_field , embedding_field , sparse_embedding_field )
162
+ self .qdrant_to_haystack = QdrantToHaystack (
163
+ content_field , name_field , embedding_field , use_sparse_embeddings , sparse_embedding_field
164
+ )
159
165
self .write_batch_size = write_batch_size
160
166
self .scroll_size = scroll_size
161
167
@@ -196,7 +202,7 @@ def write_documents(
196
202
if not isinstance (doc , Document ):
197
203
msg = f"DocumentStore.write_documents() expects a list of Documents but got an element of { type (doc )} ."
198
204
raise ValueError (msg )
199
- self ._set_up_collection (self .index , self .embedding_dim , False , self .similarity )
205
+ self ._set_up_collection (self .index , self .embedding_dim , False , self .similarity , self . use_sparse_embeddings )
200
206
201
207
if len (documents ) == 0 :
202
208
logger .warning ("Calling QdrantDocumentStore.write_documents() with empty list" )
@@ -214,6 +220,7 @@ def write_documents(
214
220
batch = self .haystack_to_qdrant_converter .documents_to_batch (
215
221
document_batch ,
216
222
embedding_field = self .embedding_field ,
223
+ use_sparse_embeddings = self .use_sparse_embeddings ,
217
224
sparse_embedding_field = self .sparse_embedding_field ,
218
225
)
219
226
@@ -309,10 +316,17 @@ def query_by_sparse(
309
316
scale_score : bool = True , # noqa: FBT001, FBT002
310
317
return_embedding : bool = False , # noqa: FBT001, FBT002
311
318
) -> List [Document ]:
319
+
320
+ if not self .use_sparse_embeddings :
321
+ message = (
322
+ "Error: tried to query by sparse vector with a Qdrant "
323
+ "Document Store initialized with use_sparse_embeddings=False"
324
+ )
325
+ raise ValueError (message )
326
+
312
327
qdrant_filters = self .qdrant_filter_converter .convert (filters )
313
328
query_indices = query_sparse_embedding .indices
314
329
query_values = query_sparse_embedding .values
315
-
316
330
points = self .client .search (
317
331
collection_name = self .index ,
318
332
query_vector = rest .NamedSparseVector (
@@ -326,7 +340,6 @@ def query_by_sparse(
326
340
limit = top_k ,
327
341
with_vectors = return_embedding ,
328
342
)
329
-
330
343
results = [self .qdrant_to_haystack .point_to_document (point ) for point in points ]
331
344
if scale_score :
332
345
for document in results :
@@ -345,17 +358,25 @@ def query_by_embedding(
345
358
) -> List [Document ]:
346
359
qdrant_filters = self .qdrant_filter_converter .convert (filters )
347
360
348
- points = self .client .search (
349
- collection_name = self .index ,
350
- query_vector = rest .NamedVector (
351
- name = DENSE_VECTORS_NAME ,
352
- vector = query_embedding ,
353
- ),
354
- query_filter = qdrant_filters ,
355
- limit = top_k ,
356
- with_vectors = return_embedding ,
357
- )
358
-
361
+ if self .use_sparse_embeddings :
362
+ points = self .client .search (
363
+ collection_name = self .index ,
364
+ query_vector = rest .NamedVector (
365
+ name = DENSE_VECTORS_NAME ,
366
+ vector = query_embedding ,
367
+ ),
368
+ query_filter = qdrant_filters ,
369
+ limit = top_k ,
370
+ with_vectors = return_embedding ,
371
+ )
372
+ if not self .use_sparse_embeddings :
373
+ points = self .client .search (
374
+ collection_name = self .index ,
375
+ query_vector = query_embedding ,
376
+ query_filter = qdrant_filters ,
377
+ limit = top_k ,
378
+ with_vectors = return_embedding ,
379
+ )
359
380
results = [self .qdrant_to_haystack .point_to_document (point ) for point in points ]
360
381
if scale_score :
361
382
for document in results :
@@ -397,6 +418,7 @@ def _set_up_collection(
397
418
embedding_dim : int ,
398
419
recreate_collection : bool , # noqa: FBT001
399
420
similarity : str ,
421
+ use_sparse_embeddings : bool , # noqa: FBT001
400
422
on_disk : bool = False , # noqa: FBT001, FBT002
401
423
payload_fields_to_index : Optional [List [dict ]] = None ,
402
424
):
@@ -405,7 +427,7 @@ def _set_up_collection(
405
427
if recreate_collection :
406
428
# There is no need to verify the current configuration of that
407
429
# collection. It might be just recreated again.
408
- self ._recreate_collection (collection_name , distance , embedding_dim , on_disk )
430
+ self ._recreate_collection (collection_name , distance , embedding_dim , on_disk , use_sparse_embeddings )
409
431
# Create Payload index if payload_fields_to_index is provided
410
432
self ._create_payload_index (collection_name , payload_fields_to_index )
411
433
return
@@ -421,12 +443,33 @@ def _set_up_collection(
421
443
# Qdrant local raises ValueError if the collection is not found, but
422
444
# with the remote server UnexpectedResponse / RpcError is raised.
423
445
# Until that's unified, we need to catch both.
424
- self ._recreate_collection (collection_name , distance , embedding_dim , on_disk )
446
+ self ._recreate_collection (collection_name , distance , embedding_dim , on_disk , use_sparse_embeddings )
425
447
# Create Payload index if payload_fields_to_index is provided
426
448
self ._create_payload_index (collection_name , payload_fields_to_index )
427
449
return
428
- current_distance = collection_info .config .params .vectors [DENSE_VECTORS_NAME ].distance
429
- current_vector_size = collection_info .config .params .vectors [DENSE_VECTORS_NAME ].size
450
+ if self .use_sparse_embeddings :
451
+ current_distance = collection_info .config .params .vectors [DENSE_VECTORS_NAME ].distance
452
+ current_vector_size = collection_info .config .params .vectors [DENSE_VECTORS_NAME ].size
453
+ if not self .use_sparse_embeddings :
454
+ current_distance = collection_info .config .params .vectors .distance
455
+ current_vector_size = collection_info .config .params .vectors .size
456
+
457
+ if self .use_sparse_embeddings and not isinstance (collection_info .config .params .vectors , dict ):
458
+ msg = (
459
+ f"Collection '{ collection_name } ' already exists in Qdrant, "
460
+ f"but it has been originaly created without sparse embedding vectors."
461
+ f"If you want to use that collection, either set `use_sparse_embeddings=False` "
462
+ f"or run a migration script "
463
+ f"to use Named Dense Vectors (`text-sparse`) and Named Sparse Vectors (`text-dense`)."
464
+ )
465
+ raise ValueError (msg )
466
+ if not self .use_sparse_embeddings and isinstance (collection_info .config .params .vectors , dict ):
467
+ msg = (
468
+ f"Collection '{ collection_name } ' already exists in Qdrant, "
469
+ f"but it has been originaly created with sparse embedding vectors."
470
+ f"If you want to use that collection, please set `use_sparse_embeddings=True`"
471
+ )
472
+ raise ValueError (msg )
430
473
431
474
if current_distance != distance :
432
475
msg = (
@@ -446,33 +489,59 @@ def _set_up_collection(
446
489
)
447
490
raise ValueError (msg )
448
491
449
- def _recreate_collection (self , collection_name : str , distance , embedding_dim : int , on_disk : bool ): # noqa: FBT001
450
- self .client .recreate_collection (
451
- collection_name = collection_name ,
452
- vectors_config = {
453
- DENSE_VECTORS_NAME : rest .VectorParams (
492
+ def _recreate_collection (
493
+ self ,
494
+ collection_name : str ,
495
+ distance ,
496
+ embedding_dim : int ,
497
+ on_disk : bool , # noqa: FBT001
498
+ use_sparse_embeddings : bool , # noqa: FBT001
499
+ ):
500
+ if use_sparse_embeddings :
501
+ self .client .recreate_collection (
502
+ collection_name = collection_name ,
503
+ vectors_config = {
504
+ DENSE_VECTORS_NAME : rest .VectorParams (
505
+ size = embedding_dim ,
506
+ on_disk = on_disk ,
507
+ distance = distance ,
508
+ ),
509
+ },
510
+ sparse_vectors_config = {
511
+ SPARSE_VECTORS_NAME : rest .SparseVectorParams (
512
+ index = rest .SparseIndexParams (
513
+ on_disk = on_disk ,
514
+ )
515
+ )
516
+ },
517
+ shard_number = self .shard_number ,
518
+ replication_factor = self .replication_factor ,
519
+ write_consistency_factor = self .write_consistency_factor ,
520
+ on_disk_payload = self .on_disk_payload ,
521
+ hnsw_config = self .hnsw_config ,
522
+ optimizers_config = self .optimizers_config ,
523
+ wal_config = self .wal_config ,
524
+ quantization_config = self .quantization_config ,
525
+ init_from = self .init_from ,
526
+ )
527
+ if not use_sparse_embeddings :
528
+ self .client .recreate_collection (
529
+ collection_name = collection_name ,
530
+ vectors_config = rest .VectorParams (
454
531
size = embedding_dim ,
455
532
on_disk = on_disk ,
456
533
distance = distance ,
457
534
),
458
- },
459
- sparse_vectors_config = {
460
- SPARSE_VECTORS_NAME : rest .SparseVectorParams (
461
- index = rest .SparseIndexParams (
462
- on_disk = on_disk ,
463
- )
464
- )
465
- },
466
- shard_number = self .shard_number ,
467
- replication_factor = self .replication_factor ,
468
- write_consistency_factor = self .write_consistency_factor ,
469
- on_disk_payload = self .on_disk_payload ,
470
- hnsw_config = self .hnsw_config ,
471
- optimizers_config = self .optimizers_config ,
472
- wal_config = self .wal_config ,
473
- quantization_config = self .quantization_config ,
474
- init_from = self .init_from ,
475
- )
535
+ shard_number = self .shard_number ,
536
+ replication_factor = self .replication_factor ,
537
+ write_consistency_factor = self .write_consistency_factor ,
538
+ on_disk_payload = self .on_disk_payload ,
539
+ hnsw_config = self .hnsw_config ,
540
+ optimizers_config = self .optimizers_config ,
541
+ wal_config = self .wal_config ,
542
+ quantization_config = self .quantization_config ,
543
+ init_from = self .init_from ,
544
+ )
476
545
477
546
def _handle_duplicate_documents (
478
547
self ,
0 commit comments