Skip to content

Commit a1eb133

Browse files
committed
feat: add similarity search with distance score
1 parent 691eafb commit a1eb133

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

src/langchain_google_firestore/vectorstores.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def _similarity_search(
240240
query: List[float],
241241
k: int = DEFAULT_TOP_K,
242242
filters: Optional[BaseFilter] = None,
243+
with_scores: Optional[bool] = False,
243244
) -> List[DocumentSnapshot]:
244245
_filters = filters or self.filters
245246

@@ -253,6 +254,7 @@ def _similarity_search(
253254
query_vector=Vector(query),
254255
distance_measure=self.distance_strategy,
255256
limit=k,
257+
distance_result_field= 'distance' if with_scores else None
256258
)
257259

258260
return results.get()
@@ -413,6 +415,14 @@ def max_marginal_relevance_search_by_vector(
413415
)
414416
return [convert_firestore_document(doc_results[i]) for i in mmr_doc_indexes]
415417

418+
def similarity_search_with_score(
    self,
    query: str,
    k: int = DEFAULT_TOP_K,
    filters: Optional[BaseFilter] = None,
    **kwargs,
):
    """Run a similarity search and pair every hit with its distance.

    The query text is embedded with ``self.embedding_service`` and passed to
    ``_similarity_search`` with ``with_scores=True``, which asks Firestore to
    write the computed distance into a ``"distance"`` field on each result.

    Args:
        query: Text to embed and search for.
        k: Maximum number of results to request. Defaults to
            ``DEFAULT_TOP_K``, the same default ``_similarity_search`` uses.
        filters: Optional filter forwarded to ``_similarity_search``.
        **kwargs: Accepted for call compatibility; not used by this method.

    Returns:
        A list of ``(document, distance)`` tuples, one per returned
        snapshot. ``document`` is built by ``convert_firestore_document``
        using ``self.content_field`` as the page-content field, and
        ``distance`` is the value read from the snapshot's ``"distance"``
        field.

    Raises:
        KeyError: If a returned snapshot has no ``"distance"`` field.
    """
    query_vector = self.embedding_service.embed_query(query)
    snapshots = self._similarity_search(
        query_vector, k, filters=filters, with_scores=True
    )
    return [
        (
            convert_firestore_document(
                snapshot, page_content_fields=[self.content_field]
            ),
            # Must match the field name requested by _similarity_search.
            snapshot.to_dict()["distance"],
        )
        for snapshot in snapshots
    ]
416426
@classmethod
417427
def from_texts(
418428
cls: Type[FirestoreVectorStore],

tests/test_vectorstores.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,76 @@ def test_firestore_max_marginal_relevance_by_vector(
495495
test_case.assertEqual(len(results), k)
496496

497497

498+
def test_firestore_similarity_search_with_score(
    test_case: TestCase,
    test_collection: str,
    client: firestore.Client,
    embeddings: FakeEmbeddings,
):
    """
    An end-to-end test for similarity search with score in FirestoreVectorStore.

    Stores two texts, searches for one of them with ``k=2`` and checks that
    every result is a ``(Document, float)`` tuple.
    """

    # Create FirestoreVectorStore instance
    firestore_store = FirestoreVectorStore(test_collection, embeddings, client=client)

    texts = ["test1", "test2"]
    k = 2

    # Add vectors to Firestore
    firestore_store.add_texts(texts, ids=["1", "2"])

    # Perform similarity search with score
    results = firestore_store.similarity_search_with_score("test1", k)

    # Verify that the search results are as expected
    test_case.assertEqual(len(results), k)

    # Check that each result is a tuple with a Document and a score.
    # assertIsInstance reports the offending type on failure, unlike
    # assertTrue(isinstance(...)).
    for result in results:
        test_case.assertIsInstance(result, tuple)
        test_case.assertEqual(len(result), 2)
        test_case.assertIsInstance(result[0], Document)
        test_case.assertIsInstance(result[1], float)
529+
530+
531+
def test_firestore_similarity_search_with_score_with_filters(
    test_case: TestCase,
    test_collection: str,
    client: firestore.Client,
    embeddings: FakeEmbeddings,
):
    """
    An end-to-end test for similarity search with score in FirestoreVectorStore with filters.
    Requires an index on the filter field in Firestore.

    Stores two texts with different ``foo`` metadata, searches with a filter
    on ``metadata.foo == "bar"`` and checks that only the matching document
    comes back, paired with a float score.
    """

    # Create FirestoreVectorStore instance
    firestore_store = FirestoreVectorStore(test_collection, embeddings, client=client)

    # Add vectors to Firestore
    firestore_store.add_texts(
        ["test1", "test2"],
        ids=["1", "2"],
        metadatas=[{"foo": "bar"}, {"foo": "baz"}],
    )

    # Perform similarity search with score and filter
    results = firestore_store.similarity_search_with_score(
        "test1", k=2, filters=FieldFilter("metadata.foo", "==", "bar")
    )

    # Verify that the search results are as expected with the filter applied
    test_case.assertEqual(len(results), 1)

    # Check that the result is a tuple with a Document and a score.
    # assertIsInstance reports the offending type on failure, unlike
    # assertTrue(isinstance(...)).
    doc, score = results[0]
    test_case.assertIsInstance(doc, Document)
    test_case.assertIsInstance(score, float)
    test_case.assertEqual(doc.page_content, "test1")
    test_case.assertEqual(doc.metadata["metadata"]["foo"], "bar")
567+
498568
def test_firestore_from_texts(
499569
test_case: TestCase,
500570
test_collection: str,

0 commit comments

Comments
 (0)