Skip to content

Commit 28bf298

Browse files
authored
new: update type hints (#64)
* new: update type hints * fix: do not pass location and path to qdrant client, and do not accept them together * new: update settings tests * fix: revert removal of local path
1 parent b657656 commit 28bf298

File tree

6 files changed

+83
-78
lines changed

6 files changed

+83
-78
lines changed

src/mcp_server_qdrant/embeddings/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
from abc import ABC, abstractmethod
2-
from typing import List
32

43

54
class EmbeddingProvider(ABC):
65
"""Abstract base class for embedding providers."""
76

87
@abstractmethod
9-
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
8+
async def embed_documents(self, documents: list[str]) -> list[list[float]]:
109
"""Embed a list of documents into vectors."""
1110
pass
1211

1312
@abstractmethod
14-
async def embed_query(self, query: str) -> List[float]:
13+
async def embed_query(self, query: str) -> list[float]:
1514
"""Embed a query into a vector."""
1615
pass
1716

src/mcp_server_qdrant/embeddings/fastembed.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
from typing import List
32

43
from fastembed import TextEmbedding
54
from fastembed.common.model_description import DenseModelDescription
@@ -17,7 +16,7 @@ def __init__(self, model_name: str):
1716
self.model_name = model_name
1817
self.embedding_model = TextEmbedding(model_name)
1918

20-
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
19+
async def embed_documents(self, documents: list[str]) -> list[list[float]]:
2120
"""Embed a list of documents into vectors."""
2221
# Run in a thread pool since FastEmbed is synchronous
2322
loop = asyncio.get_event_loop()
@@ -26,7 +25,7 @@ async def embed_documents(self, documents: List[str]) -> List[List[float]]:
2625
)
2726
return [embedding.tolist() for embedding in embeddings]
2827

29-
async def embed_query(self, query: str) -> List[float]:
28+
async def embed_query(self, query: str) -> list[float]:
3029
"""Embed a query into a vector."""
3130
# Run in a thread pool since FastEmbed is synchronous
3231
loop = asyncio.get_event_loop()

src/mcp_server_qdrant/mcp_server.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import logging
3-
from typing import Annotated, Any, List, Optional
3+
from typing import Annotated, Any
44

55
from fastmcp import Context, FastMCP
66
from pydantic import Field
@@ -76,7 +76,7 @@ async def store(
7676
# If we set it to be optional, some of the MCP clients, like Cursor, cannot
7777
# handle the optional parameter correctly.
7878
metadata: Annotated[
79-
Optional[Metadata],
79+
Metadata | None,
8080
Field(
8181
description="Extra metadata stored along with memorised information. Any json is accepted."
8282
),
@@ -106,14 +106,15 @@ async def find(
106106
collection_name: Annotated[
107107
str, Field(description="The collection to search in")
108108
],
109-
query_filter: Optional[ArbitraryFilter] = None,
110-
) -> List[str]:
109+
query_filter: ArbitraryFilter | None = None,
110+
) -> list[str]:
111111
"""
112112
Find memories in Qdrant.
113113
:param ctx: The context for the request.
114114
:param query: The query to use for the search.
115115
:param collection_name: The name of the collection to search in, optional. If not provided,
116116
the default collection is used.
117+
:param query_filter: The filter to apply to the query.
117118
:return: A list of entries found.
118119
"""
119120

@@ -123,10 +124,6 @@ async def find(
123124
query_filter = models.Filter(**query_filter) if query_filter else None
124125

125126
await ctx.debug(f"Finding results for query {query}")
126-
if collection_name:
127-
await ctx.debug(
128-
f"Overriding the collection name with {collection_name}"
129-
)
130127

131128
entries = await self.qdrant_connector.search(
132129
query,

src/mcp_server_qdrant/qdrant.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import uuid
3-
from typing import Any, Dict, Optional
3+
from typing import Any
44

55
from pydantic import BaseModel
66
from qdrant_client import AsyncQdrantClient, models
@@ -10,9 +10,8 @@
1010

1111
logger = logging.getLogger(__name__)
1212

13-
Metadata = Dict[str, Any]
14-
15-
ArbitraryFilter = Dict[str, Any]
13+
Metadata = dict[str, Any]
14+
ArbitraryFilter = dict[str, Any]
1615

1716

1817
class Entry(BaseModel):
@@ -21,7 +20,7 @@ class Entry(BaseModel):
2120
"""
2221

2322
content: str
24-
metadata: Optional[Metadata] = None
23+
metadata: Metadata | None = None
2524

2625

2726
class QdrantConnector:
@@ -37,12 +36,12 @@ class QdrantConnector:
3736

3837
def __init__(
3938
self,
40-
qdrant_url: Optional[str],
41-
qdrant_api_key: Optional[str],
42-
collection_name: Optional[str],
39+
qdrant_url: str | None,
40+
qdrant_api_key: str | None,
41+
collection_name: str | None,
4342
embedding_provider: EmbeddingProvider,
44-
qdrant_local_path: Optional[str] = None,
45-
field_indexes: Optional[dict[str, models.PayloadSchemaType]] = None,
43+
qdrant_local_path: str | None = None,
44+
field_indexes: dict[str, models.PayloadSchemaType] | None = None,
4645
):
4746
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
4847
self._qdrant_api_key = qdrant_api_key
@@ -61,7 +60,7 @@ async def get_collection_names(self) -> list[str]:
6160
response = await self._client.get_collections()
6261
return [collection.name for collection in response.collections]
6362

64-
async def store(self, entry: Entry, *, collection_name: Optional[str] = None):
63+
async def store(self, entry: Entry, *, collection_name: str | None = None):
6564
"""
6665
Store some information in the Qdrant collection, along with the specified metadata.
6766
:param entry: The entry to store in the Qdrant collection.
@@ -95,16 +94,18 @@ async def search(
9594
self,
9695
query: str,
9796
*,
98-
collection_name: Optional[str] = None,
97+
collection_name: str | None = None,
9998
limit: int = 10,
100-
query_filter: Optional[models.Filter] = None,
99+
query_filter: models.Filter | None = None,
101100
) -> list[Entry]:
102101
"""
103102
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
104103
:param query: The query to use for the search.
105104
:param collection_name: The name of the collection to search in, optional. If not provided,
106105
the default collection is used.
107106
:param limit: The maximum number of entries to return.
107+
:param query_filter: The filter to apply to the query, if any.
108+
108109
:return: A list of entries found.
109110
"""
110111
collection_name = collection_name or self._default_collection_name

src/mcp_server_qdrant/settings.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from typing import Literal, Optional
1+
from typing import Literal
22

3-
from pydantic import BaseModel, Field
3+
from pydantic import BaseModel, Field, model_validator
44
from pydantic_settings import BaseSettings
55

66
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
@@ -56,7 +56,7 @@ class FilterableField(BaseModel):
5656
field_type: Literal["keyword", "integer", "float", "boolean"] = Field(
5757
description="The type of the field"
5858
)
59-
condition: Optional[Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"]] = (
59+
condition: Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"] | None = (
6060
Field(
6161
default=None,
6262
description=(
@@ -76,18 +76,16 @@ class QdrantSettings(BaseSettings):
7676
Configuration for the Qdrant connector.
7777
"""
7878

79-
location: Optional[str] = Field(default=None, validation_alias="QDRANT_URL")
80-
api_key: Optional[str] = Field(default=None, validation_alias="QDRANT_API_KEY")
81-
collection_name: Optional[str] = Field(
79+
location: str | None = Field(default=None, validation_alias="QDRANT_URL")
80+
api_key: str | None = Field(default=None, validation_alias="QDRANT_API_KEY")
81+
collection_name: str | None = Field(
8282
default=None, validation_alias="COLLECTION_NAME"
8383
)
84-
local_path: Optional[str] = Field(
85-
default=None, validation_alias="QDRANT_LOCAL_PATH"
86-
)
84+
local_path: str | None = Field(default=None, validation_alias="QDRANT_LOCAL_PATH")
8785
search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT")
8886
read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY")
8987

90-
filterable_fields: Optional[list[FilterableField]] = Field(default=None)
88+
filterable_fields: list[FilterableField] | None = Field(default=None)
9189

9290
allow_arbitrary_filter: bool = Field(
9391
default=False, validation_alias="QDRANT_ALLOW_ARBITRARY_FILTER"
@@ -106,3 +104,12 @@ def filterable_fields_dict_with_conditions(self) -> dict[str, FilterableField]:
106104
for field in self.filterable_fields
107105
if field.condition is not None
108106
}
107+
108+
@model_validator(mode="after")
109+
def check_local_path_conflict(self) -> "QdrantSettings":
110+
if self.local_path:
111+
if self.location is not None or self.api_key is not None:
112+
raise ValueError(
113+
"If 'local_path' is set, 'location' and 'api_key' must be None."
114+
)
115+
return self

tests/test_settings.py

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import os
2-
from unittest.mock import patch
1+
import pytest
32

43
from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
54
from mcp_server_qdrant.settings import (
@@ -18,34 +17,51 @@ def test_default_values(self):
1817
# Should not raise error because there are no required fields
1918
QdrantSettings()
2019

21-
@patch.dict(
22-
os.environ,
23-
{"QDRANT_URL": "http://localhost:6333", "COLLECTION_NAME": "test_collection"},
24-
)
25-
def test_minimal_config(self):
20+
def test_minimal_config(self, monkeypatch):
2621
"""Test loading minimal configuration from environment variables."""
22+
monkeypatch.setenv("QDRANT_URL", "http://localhost:6333")
23+
monkeypatch.setenv("COLLECTION_NAME", "test_collection")
24+
2725
settings = QdrantSettings()
2826
assert settings.location == "http://localhost:6333"
2927
assert settings.collection_name == "test_collection"
3028
assert settings.api_key is None
3129
assert settings.local_path is None
3230

33-
@patch.dict(
34-
os.environ,
35-
{
36-
"QDRANT_URL": "http://qdrant.example.com:6333",
37-
"QDRANT_API_KEY": "test_api_key",
38-
"COLLECTION_NAME": "my_memories",
39-
"QDRANT_LOCAL_PATH": "/tmp/qdrant",
40-
},
41-
)
42-
def test_full_config(self):
31+
def test_full_config(self, monkeypatch):
4332
"""Test loading full configuration from environment variables."""
33+
monkeypatch.setenv("QDRANT_URL", "http://qdrant.example.com:6333")
34+
monkeypatch.setenv("QDRANT_API_KEY", "test_api_key")
35+
monkeypatch.setenv("COLLECTION_NAME", "my_memories")
36+
monkeypatch.setenv("QDRANT_SEARCH_LIMIT", "15")
37+
monkeypatch.setenv("QDRANT_READ_ONLY", "1")
38+
4439
settings = QdrantSettings()
4540
assert settings.location == "http://qdrant.example.com:6333"
4641
assert settings.api_key == "test_api_key"
4742
assert settings.collection_name == "my_memories"
48-
assert settings.local_path == "/tmp/qdrant"
43+
assert settings.search_limit == 15
44+
assert settings.read_only is True
45+
46+
def test_local_path_config(self, monkeypatch):
47+
"""Test loading local path configuration from environment variables."""
48+
monkeypatch.setenv("QDRANT_LOCAL_PATH", "/path/to/local/qdrant")
49+
50+
settings = QdrantSettings()
51+
assert settings.local_path == "/path/to/local/qdrant"
52+
53+
def test_local_path_is_exclusive_with_url(self, monkeypatch):
54+
"""Test that local path cannot be set if Qdrant URL is provided."""
55+
monkeypatch.setenv("QDRANT_URL", "http://localhost:6333")
56+
monkeypatch.setenv("QDRANT_LOCAL_PATH", "/path/to/local/qdrant")
57+
58+
with pytest.raises(ValueError):
59+
QdrantSettings()
60+
61+
monkeypatch.delenv("QDRANT_URL", raising=False)
62+
monkeypatch.setenv("QDRANT_API_KEY", "test_api_key")
63+
with pytest.raises(ValueError):
64+
QdrantSettings()
4965

5066

5167
class TestEmbeddingProviderSettings:
@@ -55,12 +71,9 @@ def test_default_values(self):
5571
assert settings.provider_type == EmbeddingProviderType.FASTEMBED
5672
assert settings.model_name == "sentence-transformers/all-MiniLM-L6-v2"
5773

58-
@patch.dict(
59-
os.environ,
60-
{"EMBEDDING_MODEL": "custom_model"},
61-
)
62-
def test_custom_values(self):
74+
def test_custom_values(self, monkeypatch):
6375
"""Test loading custom values from environment variables."""
76+
monkeypatch.setenv("EMBEDDING_MODEL", "custom_model")
6477
settings = EmbeddingProviderSettings()
6578
assert settings.provider_type == EmbeddingProviderType.FASTEMBED
6679
assert settings.model_name == "custom_model"
@@ -73,35 +86,24 @@ def test_default_values(self):
7386
assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION
7487
assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION
7588

76-
@patch.dict(
77-
os.environ,
78-
{"TOOL_STORE_DESCRIPTION": "Custom store description"},
79-
)
80-
def test_custom_store_description(self):
89+
def test_custom_store_description(self, monkeypatch):
8190
"""Test loading custom store description from environment variable."""
91+
monkeypatch.setenv("TOOL_STORE_DESCRIPTION", "Custom store description")
8292
settings = ToolSettings()
8393
assert settings.tool_store_description == "Custom store description"
8494
assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION
8595

86-
@patch.dict(
87-
os.environ,
88-
{"TOOL_FIND_DESCRIPTION": "Custom find description"},
89-
)
90-
def test_custom_find_description(self):
96+
def test_custom_find_description(self, monkeypatch):
9197
"""Test loading custom find description from environment variable."""
98+
monkeypatch.setenv("TOOL_FIND_DESCRIPTION", "Custom find description")
9299
settings = ToolSettings()
93100
assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION
94101
assert settings.tool_find_description == "Custom find description"
95102

96-
@patch.dict(
97-
os.environ,
98-
{
99-
"TOOL_STORE_DESCRIPTION": "Custom store description",
100-
"TOOL_FIND_DESCRIPTION": "Custom find description",
101-
},
102-
)
103-
def test_all_custom_values(self):
103+
def test_all_custom_values(self, monkeypatch):
104104
"""Test loading all custom values from environment variables."""
105+
monkeypatch.setenv("TOOL_STORE_DESCRIPTION", "Custom store description")
106+
monkeypatch.setenv("TOOL_FIND_DESCRIPTION", "Custom find description")
105107
settings = ToolSettings()
106108
assert settings.tool_store_description == "Custom store description"
107109
assert settings.tool_find_description == "Custom find description"

0 commit comments

Comments
 (0)