Skip to content

Commit f4984fb

Browse files
committed
Refactoring to allow both list of lists, and single lists, for embeddings, due to provider differences
1 parent 8453040 commit f4984fb

File tree

5 files changed

+67
-15
lines changed

5 files changed

+67
-15
lines changed

backend/src/app/services/vector_db/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,16 @@ async def ensure_collection_exists(self) -> None:
8282
async def get_embeddings(
8383
self, texts: Union[str, List[str]]
8484
) -> List[List[float]]:
85-
"""Get embeddings for the given text(s) using the LLM service."""
85+
"""Get embeddings for the given text(s) using the embedding service."""
8686
if isinstance(texts, str):
8787
texts = [texts]
8888
return await self.embedding_service.get_embeddings(texts)
8989

90+
async def get_single_embedding(self, text: str) -> List[float]:
91+
"""Get a single embedding for the given text."""
92+
embeddings = await self.get_embeddings(text)
93+
return embeddings[0]
94+
9095
async def prepare_chunks(
9196
self, document_id: str, chunks: List[Document]
9297
) -> List[Dict[str, Any]]:

backend/src/app/services/vector_db/milvus_service.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,8 @@ async def vector_search(
155155
# Search for each query
156156
for query in queries:
157157
logger.info("Generating embedding.")
158-
159-
# Embed the query
160-
embedded_query = await self.get_embeddings(query)
158+
# Use get_single_embedding but wrap result in list for Milvus
159+
embedded_query = [await self.get_single_embedding(query)]
161160

162161
logger.info("Searching...")
163162

@@ -323,7 +322,7 @@ def count_keywords(text: str, keywords: List[str]) -> int:
323322
)
324323

325324
# Embed the query
326-
embedded_query = await self.get_embeddings(query)
325+
embedded_query = [await self.get_single_embedding(query)]
327326

328327
try:
329328
# First, let's check if there are any vectors for this document_id

backend/src/app/services/vector_db/qdrant_service.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ async def vector_search(
7575

7676
for query in queries:
7777
logger.info("Generating embedding.")
78-
embedded_query = await self.get_embeddings(query)
78+
embedded_query = await self.get_single_embedding(query)
7979
logger.info("Searching...")
8080

8181
query_response = self.client.query_points(
@@ -162,7 +162,7 @@ def count_keywords(text: str, keywords: List[str]) -> int:
162162
reverse=True,
163163
)
164164

165-
embedded_query = await self.get_embeddings(query)
165+
embedded_query = await self.get_single_embedding(query)
166166
logger.info("Running semantic similarity search.")
167167

168168
semantic_response = self.client.query_points(
@@ -194,8 +194,6 @@ def count_keywords(text: str, keywords: List[str]) -> int:
194194
combined_chunks, key=lambda chunk: chunk["chunk_number"]
195195
)
196196

197-
# Optionally, for each chunk, retrieve neighbouring chunks to ensure full context is retrieved
198-
199197
# Eliminate duplicate chunks
200198
seen_chunks = set()
201199
formatted_output = []

backend/tests/test_service_vector_db_milvus.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self, embedding_service, llm_service, settings):
1313
self.embedding_service = embedding_service
1414
self.llm_service = llm_service
1515
self.settings = settings
16-
self.client = Mock() # Use regular Mock instead of AsyncMock
16+
self.client = Mock()
1717

1818
# Set up synchronous return values
1919
self.client.has_collection.return_value = True
@@ -36,12 +36,17 @@ async def upsert_vectors(self, vectors):
3636
}
3737

3838
async def vector_search(self, queries, document_id):
39+
# Mock using get_single_embedding
40+
for query in queries:
41+
_ = await self.get_single_embedding(query)
3942
return VectorResponseSchema(message="success", chunks=[])
4043

4144
async def keyword_search(self, query, document_id, keywords):
4245
return VectorResponseSchema(message="success", chunks=[])
4346

4447
async def hybrid_search(self, query, document_id, rules):
48+
# Mock using get_single_embedding
49+
_ = await self.get_single_embedding(query)
4550
return VectorResponseSchema(
4651
message="Query processed successfully.", chunks=[]
4752
)
@@ -107,3 +112,27 @@ async def test_delete_document(vector_db_service):
107112

108113
assert result["status"] == "success"
109114
assert result["message"] == "Document deleted successfully."
115+
116+
117+
@pytest.mark.asyncio
118+
async def test_get_single_embedding(vector_db_service):
119+
# Reset the mock before the test
120+
vector_db_service.embedding_service.get_embeddings.reset_mock()
121+
122+
# Mock the embedding service to return a known value
123+
vector_db_service.embedding_service.get_embeddings.return_value = [
124+
[0.1, 0.2, 0.3]
125+
]
126+
127+
# Test getting a single embedding
128+
result = await vector_db_service.get_single_embedding("test text")
129+
130+
# Verify the result
131+
assert isinstance(result, list)
132+
assert len(result) == 3 # Length of our mock embedding
133+
assert result == [0.1, 0.2, 0.3]
134+
135+
# Verify the embedding service was called correctly
136+
vector_db_service.embedding_service.get_embeddings.assert_called_once_with(
137+
["test text"]
138+
)

backend/tests/test_service_vector_db_qdrant.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ async def test_upsert_vectors(qdrant_service):
7777

7878
@pytest.mark.asyncio
7979
async def test_vector_search(qdrant_service, mock_embeddings_service):
80-
mock_embeddings_service.get_embeddings.return_value = [0.1, 0.2]
80+
mock_embeddings_service.get_embeddings.return_value = [[0.1, 0.2]]
8181

8282
result = await qdrant_service.vector_search(["test query"], "test_doc")
8383

@@ -88,11 +88,8 @@ async def test_vector_search(qdrant_service, mock_embeddings_service):
8888

8989
@pytest.mark.asyncio
9090
async def test_hybrid_search(qdrant_service, mock_embeddings_service):
91-
# Mock the embedding service response
92-
mock_embeddings_service.get_embeddings.return_value = [0.1, 0.2]
91+
mock_embeddings_service.get_embeddings.return_value = [[0.1, 0.2]]
9392

94-
# Mock the extract_keywords method directly on the qdrant_service
95-
# since it's a method of QdrantService, not CompletionService
9693
with patch.object(
9794
qdrant_service,
9895
"extract_keywords",
@@ -134,3 +131,27 @@ async def test_delete_document(qdrant_service):
134131
async def test_keyword_search_not_implemented(qdrant_service):
135132
with pytest.raises(NotImplementedError):
136133
await qdrant_service.keyword_search("query", "doc_id", ["keyword"])
134+
135+
136+
@pytest.mark.asyncio
137+
async def test_get_single_embedding(qdrant_service):
138+
# Reset the mock before the test
139+
qdrant_service.embedding_service.get_embeddings.reset_mock()
140+
141+
# Mock the embedding service to return a known value
142+
qdrant_service.embedding_service.get_embeddings.return_value = [
143+
[0.1, 0.2, 0.3]
144+
]
145+
146+
# Test getting a single embedding
147+
result = await qdrant_service.get_single_embedding("test text")
148+
149+
# Verify the result
150+
assert isinstance(result, list)
151+
assert len(result) == 3
152+
assert result == [0.1, 0.2, 0.3]
153+
154+
# Verify the embedding service was called correctly
155+
qdrant_service.embedding_service.get_embeddings.assert_called_once_with(
156+
["test text"]
157+
)

0 commit comments

Comments
 (0)