Skip to content

Commit 58018c2

Browse files
Feat: Implement keyword search in milvus
Signed-off-by: Varsha Prasad Narsing <[email protected]>
1 parent 28ca00d commit 58018c2

File tree

3 files changed

+237
-5
lines changed

3 files changed

+237
-5
lines changed

docs/source/providers/vector_io/milvus.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,12 @@ vector_io:
101101
- **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
102102
- **`client_key_path`**: Path to the **client private key** file (required for mTLS).
103103

104+
## Supported Search Modes
105+
106+
The Milvus provider supports both vector-based and keyword-based (full-text) search modes.
107+
108+
When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in `RAGQueryConfig`. For more details on Milvus's implementation of keyword search modes, refer to the [Milvus documentation](https://milvus.io/docs/full_text_search_with_milvus.md).
109+
104110
## Documentation
105111
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
106112

llama_stack/providers/remote/vector_io/milvus/milvus.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import Any
1313

1414
from numpy.typing import NDArray
15-
from pymilvus import MilvusClient
15+
from pymilvus import DataType, MilvusClient
1616

1717
from llama_stack.apis.inference import InterleavedContent
1818
from llama_stack.apis.vector_dbs import VectorDB
@@ -34,6 +34,8 @@ def __init__(self, client: MilvusClient, collection_name: str, consistency_level
3434
self.client = client
3535
self.collection_name = collection_name.replace("-", "_")
3636
self.consistency_level = consistency_level
37+
self.bm25 = None
38+
self.vectorizer = None
3739

3840
async def delete(self):
3941
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
@@ -44,11 +46,42 @@ async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
4446
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
4547
)
4648
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
49+
# Create schema for vector search
50+
schema = self.client.create_schema()
51+
schema.add_field(
52+
field_name="chunk_id",
53+
datatype=DataType.VARCHAR,
54+
is_primary=True,
55+
max_length=100,
56+
)
57+
schema.add_field(
58+
field_name="content",
59+
datatype=DataType.VARCHAR,
60+
max_length=65535,
61+
)
62+
schema.add_field(
63+
field_name="vector",
64+
datatype=DataType.FLOAT_VECTOR,
65+
dim=len(embeddings[0]),
66+
)
67+
schema.add_field(
68+
field_name="chunk_content",
69+
datatype=DataType.JSON,
70+
)
71+
72+
# Create indexes
73+
index_params = self.client.prepare_index_params()
74+
index_params.add_index(
75+
field_name="vector",
76+
index_type="FLAT",
77+
metric_type="COSINE",
78+
)
79+
4780
await asyncio.to_thread(
4881
self.client.create_collection,
4982
self.collection_name,
50-
dimension=len(embeddings[0]),
51-
auto_id=True,
83+
schema=schema,
84+
index_params=index_params,
5285
consistency_level=self.consistency_level,
5386
)
5487

@@ -59,6 +92,7 @@ async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
5992
data.append(
6093
{
6194
"chunk_id": chunk_id,
95+
"content": chunk.content,
6296
"vector": embedding,
6397
"chunk_content": chunk.model_dump(),
6498
}
@@ -78,9 +112,10 @@ async def query_vector(self, embedding: NDArray, k: int, score_threshold: float)
78112
self.client.search,
79113
collection_name=self.collection_name,
80114
data=[embedding],
115+
anns_field="vector",
81116
limit=k,
82117
output_fields=["*"],
83-
search_params={"params": {"radius": score_threshold}},
118+
search_params={"metric_type": "COSINE", "params": {"score_threshold": score_threshold}},
84119
)
85120
chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
86121
scores = [res["distance"] for res in search_res[0]]
@@ -92,7 +127,17 @@ async def query_keyword(
92127
k: int,
93128
score_threshold: float,
94129
) -> QueryChunksResponse:
95-
raise NotImplementedError("Keyword search is not supported in Milvus")
130+
# Simple text search using content field
131+
search_res = await asyncio.to_thread(
132+
self.client.query,
133+
collection_name=self.collection_name,
134+
filter=f'content like "%{query_string}%"',
135+
output_fields=["*"],
136+
limit=k,
137+
)
138+
chunks = [Chunk(**res["chunk_content"]) for res in search_res]
139+
scores = [1.0] * len(chunks) # Simple binary score for text search
140+
return QueryChunksResponse(chunks=chunks, scores=scores)
96141

97142

98143
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from unittest.mock import MagicMock, patch
8+
9+
import numpy as np
10+
import pytest
11+
import pytest_asyncio
12+
13+
from llama_stack.apis.vector_io import QueryChunksResponse
14+
15+
# Mock the entire pymilvus module
16+
pymilvus_mock = MagicMock()
17+
pymilvus_mock.DataType = MagicMock()
18+
pymilvus_mock.MilvusClient = MagicMock
19+
20+
# Apply the mock before importing MilvusIndex
21+
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
22+
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
23+
24+
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
25+
# tests which are specific to this class. More general (API-level) tests should be placed in
26+
# tests/integration/vector_io/
27+
#
28+
# How to run this test:
29+
#
30+
# pytest tests/unit/providers/vector_io/test_milvus.py \
31+
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
32+
33+
MILVUS_PROVIDER = "milvus"
34+
35+
36+
@pytest_asyncio.fixture
37+
async def mock_milvus_client():
38+
"""Create a mock Milvus client with common method behaviors."""
39+
client = MagicMock()
40+
41+
# Mock collection operations
42+
client.has_collection.return_value = False # Initially no collection
43+
client.create_collection.return_value = None
44+
client.drop_collection.return_value = None
45+
46+
# Mock insert operation
47+
client.insert.return_value = {"insert_count": 10}
48+
49+
# Mock search operation - return mock results (data should be dict, not JSON string)
50+
client.search.return_value = [
51+
[
52+
{
53+
"id": 0,
54+
"distance": 0.1,
55+
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
56+
},
57+
{
58+
"id": 1,
59+
"distance": 0.2,
60+
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
61+
},
62+
]
63+
]
64+
65+
# Mock query operation for keyword search (data should be dict, not JSON string)
66+
client.query.return_value = [
67+
{
68+
"chunk_id": "chunk1",
69+
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
70+
"score": 0.9,
71+
},
72+
{
73+
"chunk_id": "chunk2",
74+
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
75+
"score": 0.8,
76+
},
77+
{
78+
"chunk_id": "chunk3",
79+
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
80+
"score": 0.7,
81+
},
82+
]
83+
84+
return client
85+
86+
87+
@pytest_asyncio.fixture
88+
async def milvus_index(mock_milvus_client):
89+
"""Create a MilvusIndex with mocked client."""
90+
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
91+
yield index
92+
# No real cleanup needed since we're using mocks
93+
94+
95+
@pytest.mark.asyncio
96+
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
97+
# Setup: collection doesn't exist initially, then exists after creation
98+
mock_milvus_client.has_collection.side_effect = [False, True]
99+
100+
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
101+
102+
# Verify collection was created and data was inserted
103+
mock_milvus_client.create_collection.assert_called_once()
104+
mock_milvus_client.insert.assert_called_once()
105+
106+
# Verify the insert call had the right number of chunks
107+
insert_call = mock_milvus_client.insert.call_args
108+
assert len(insert_call[1]["data"]) == len(sample_chunks)
109+
110+
111+
@pytest.mark.asyncio
112+
async def test_query_chunks_vector(
113+
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
114+
):
115+
# Setup: Add chunks first
116+
mock_milvus_client.has_collection.return_value = True
117+
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
118+
119+
# Test vector search
120+
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
121+
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
122+
123+
assert isinstance(response, QueryChunksResponse)
124+
assert len(response.chunks) == 2
125+
mock_milvus_client.search.assert_called_once()
126+
127+
128+
@pytest.mark.asyncio
129+
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
130+
# Setup: Add chunks first
131+
mock_milvus_client.has_collection.return_value = True
132+
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
133+
134+
# Test keyword search
135+
query_string = "Sentence 5"
136+
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
137+
138+
assert isinstance(response, QueryChunksResponse)
139+
assert len(response.chunks) == 3
140+
mock_milvus_client.query.assert_called_once()
141+
142+
# Test no results case
143+
mock_milvus_client.query.return_value = []
144+
response_no_results = await milvus_index.query_keyword(query_string="nonexistent", k=1, score_threshold=0.0)
145+
146+
assert isinstance(response_no_results, QueryChunksResponse)
147+
assert len(response_no_results.chunks) == 0
148+
149+
150+
@pytest.mark.asyncio
151+
async def test_query_chunks_keyword_search_k_greater_than_results(
152+
milvus_index, sample_chunks, sample_embeddings, mock_milvus_client
153+
):
154+
# Setup: Add chunks first
155+
mock_milvus_client.has_collection.return_value = True
156+
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
157+
158+
# Mock returning only 1 result even though k=5
159+
mock_milvus_client.query.return_value = [
160+
{
161+
"chunk_id": "chunk1",
162+
"chunk_content": {"content": "Sentence 1 from document 0", "metadata": {"document_id": "doc1"}},
163+
"score": 0.9,
164+
}
165+
]
166+
167+
query_str = "Sentence 1 from document 0"
168+
response = await milvus_index.query_keyword(query_string=query_str, k=5, score_threshold=0.0)
169+
170+
assert 0 < len(response.chunks) <= 4
171+
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks)
172+
173+
174+
@pytest.mark.asyncio
175+
async def test_delete_collection(milvus_index, mock_milvus_client):
176+
# Test collection deletion
177+
mock_milvus_client.has_collection.return_value = True
178+
179+
await milvus_index.delete()
180+
181+
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)

0 commit comments

Comments
 (0)