diff --git a/src/mcp_server_qdrant/embeddings/base.py b/src/mcp_server_qdrant/embeddings/base.py index 80c1d133..5c47a175 100644 --- a/src/mcp_server_qdrant/embeddings/base.py +++ b/src/mcp_server_qdrant/embeddings/base.py @@ -1,17 +1,16 @@ from abc import ABC, abstractmethod -from typing import List class EmbeddingProvider(ABC): """Abstract base class for embedding providers.""" @abstractmethod - async def embed_documents(self, documents: List[str]) -> List[List[float]]: + async def embed_documents(self, documents: list[str]) -> list[list[float]]: """Embed a list of documents into vectors.""" pass @abstractmethod - async def embed_query(self, query: str) -> List[float]: + async def embed_query(self, query: str) -> list[float]: """Embed a query into a vector.""" pass diff --git a/src/mcp_server_qdrant/embeddings/fastembed.py b/src/mcp_server_qdrant/embeddings/fastembed.py index 628fe148..1655f835 100644 --- a/src/mcp_server_qdrant/embeddings/fastembed.py +++ b/src/mcp_server_qdrant/embeddings/fastembed.py @@ -1,5 +1,4 @@ import asyncio -from typing import List from fastembed import TextEmbedding from fastembed.common.model_description import DenseModelDescription @@ -17,7 +16,7 @@ def __init__(self, model_name: str): self.model_name = model_name self.embedding_model = TextEmbedding(model_name) - async def embed_documents(self, documents: List[str]) -> List[List[float]]: + async def embed_documents(self, documents: list[str]) -> list[list[float]]: """Embed a list of documents into vectors.""" # Run in a thread pool since FastEmbed is synchronous loop = asyncio.get_event_loop() @@ -26,7 +25,7 @@ async def embed_documents(self, documents: List[str]) -> List[List[float]]: ) return [embedding.tolist() for embedding in embeddings] - async def embed_query(self, query: str) -> List[float]: + async def embed_query(self, query: str) -> list[float]: """Embed a query into a vector.""" # Run in a thread pool since FastEmbed is synchronous loop = asyncio.get_event_loop() diff --git a/src/mcp_server_qdrant/mcp_server.py b/src/mcp_server_qdrant/mcp_server.py index 5985f7ad..63607135 100644 --- a/src/mcp_server_qdrant/mcp_server.py +++ b/src/mcp_server_qdrant/mcp_server.py @@ -1,6 +1,6 @@ import json import logging -from typing import Annotated, Any, List, Optional +from typing import Annotated, Any from fastmcp import Context, FastMCP from pydantic import Field @@ -76,7 +76,7 @@ async def store( # If we set it to be optional, some of the MCP clients, like Cursor, cannot # handle the optional parameter correctly. metadata: Annotated[ - Optional[Metadata], + Metadata | None, Field( description="Extra metadata stored along with memorised information. Any json is accepted." ), @@ -106,14 +106,15 @@ async def find( collection_name: Annotated[ str, Field(description="The collection to search in") ], - query_filter: Optional[ArbitraryFilter] = None, - ) -> List[str]: + query_filter: ArbitraryFilter | None = None, + ) -> list[str]: """ Find memories in Qdrant. :param ctx: The context for the request. :param query: The query to use for the search. :param collection_name: The name of the collection to search in, optional. If not provided, the default collection is used. + :param query_filter: The filter to apply to the query. :return: A list of entries found. """ @@ -123,10 +124,6 @@ async def find( query_filter = models.Filter(**query_filter) if query_filter else None await ctx.debug(f"Finding results for query {query}") - if collection_name: - await ctx.debug( - f"Overriding the collection name with {collection_name}" - ) entries = await self.qdrant_connector.search( query, diff --git a/src/mcp_server_qdrant/qdrant.py b/src/mcp_server_qdrant/qdrant.py index a7656f86..8d3e5aa8 100644 --- a/src/mcp_server_qdrant/qdrant.py +++ b/src/mcp_server_qdrant/qdrant.py @@ -1,6 +1,6 @@ import logging import uuid -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel from qdrant_client import AsyncQdrantClient, models @@ -10,9 +10,8 @@ logger = logging.getLogger(__name__) -Metadata = Dict[str, Any] - -ArbitraryFilter = Dict[str, Any] +Metadata = dict[str, Any] +ArbitraryFilter = dict[str, Any] class Entry(BaseModel): @@ -21,7 +20,7 @@ class Entry(BaseModel): """ content: str - metadata: Optional[Metadata] = None + metadata: Metadata | None = None class QdrantConnector: @@ -37,12 +36,12 @@ class QdrantConnector: def __init__( self, - qdrant_url: Optional[str], - qdrant_api_key: Optional[str], - collection_name: Optional[str], + qdrant_url: str | None, + qdrant_api_key: str | None, + collection_name: str | None, embedding_provider: EmbeddingProvider, - qdrant_local_path: Optional[str] = None, - field_indexes: Optional[dict[str, models.PayloadSchemaType]] = None, + qdrant_local_path: str | None = None, + field_indexes: dict[str, models.PayloadSchemaType] | None = None, ): self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None self._qdrant_api_key = qdrant_api_key @@ -61,7 +60,7 @@ async def get_collection_names(self) -> list[str]: response = await self._client.get_collections() return [collection.name for collection in response.collections] - async def store(self, entry: Entry, *, collection_name: Optional[str] = None): + async def store(self, entry: Entry, *, collection_name: str | None = None): """ Store some information in the Qdrant collection, along with the specified metadata. :param entry: The entry to store in the Qdrant collection. @@ -95,9 +94,9 @@ async def search( self, query: str, *, - collection_name: Optional[str] = None, + collection_name: str | None = None, limit: int = 10, - query_filter: Optional[models.Filter] = None, + query_filter: models.Filter | None = None, ) -> list[Entry]: """ Find points in the Qdrant collection. If there are no entries found, an empty list is returned. @@ -105,6 +104,8 @@ async def search( :param collection_name: The name of the collection to search in, optional. If not provided, the default collection is used. :param limit: The maximum number of entries to return. + :param query_filter: The filter to apply to the query, if any. + :return: A list of entries found. """ collection_name = collection_name or self._default_collection_name diff --git a/src/mcp_server_qdrant/settings.py b/src/mcp_server_qdrant/settings.py index f41e400a..e48c10d1 100644 --- a/src/mcp_server_qdrant/settings.py +++ b/src/mcp_server_qdrant/settings.py @@ -1,6 +1,6 @@ -from typing import Literal, Optional +from typing import Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from pydantic_settings import BaseSettings from mcp_server_qdrant.embeddings.types import EmbeddingProviderType @@ -56,7 +56,7 @@ class FilterableField(BaseModel): field_type: Literal["keyword", "integer", "float", "boolean"] = Field( description="The type of the field" ) - condition: Optional[Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"]] = ( + condition: Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"] | None = ( Field( default=None, description=( @@ -76,18 +76,16 @@ class QdrantSettings(BaseSettings): Configuration for the Qdrant connector. """ - location: Optional[str] = Field(default=None, validation_alias="QDRANT_URL") - api_key: Optional[str] = Field(default=None, validation_alias="QDRANT_API_KEY") - collection_name: Optional[str] = Field( + location: str | None = Field(default=None, validation_alias="QDRANT_URL") + api_key: str | None = Field(default=None, validation_alias="QDRANT_API_KEY") + collection_name: str | None = Field( default=None, validation_alias="COLLECTION_NAME" ) - local_path: Optional[str] = Field( - default=None, validation_alias="QDRANT_LOCAL_PATH" - ) + local_path: str | None = Field(default=None, validation_alias="QDRANT_LOCAL_PATH") search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT") read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY") - filterable_fields: Optional[list[FilterableField]] = Field(default=None) + filterable_fields: list[FilterableField] | None = Field(default=None) allow_arbitrary_filter: bool = Field( default=False, validation_alias="QDRANT_ALLOW_ARBITRARY_FILTER" @@ -106,3 +104,12 @@ def filterable_fields_dict_with_conditions(self) -> dict[str, FilterableField]: for field in self.filterable_fields if field.condition is not None } + + @model_validator(mode="after") + def check_local_path_conflict(self) -> "QdrantSettings": + if self.local_path: + if self.location is not None or self.api_key is not None: + raise ValueError( + "If 'local_path' is set, 'location' and 'api_key' must be None." + ) + return self diff --git a/tests/test_settings.py b/tests/test_settings.py index d5b2d37e..b5c5af9d 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,5 +1,4 @@ -import os -from unittest.mock import patch +import pytest from mcp_server_qdrant.embeddings.types import EmbeddingProviderType from mcp_server_qdrant.settings import ( @@ -18,34 +17,51 @@ def test_default_values(self): # Should not raise error because there are no required fields QdrantSettings() - @patch.dict( - os.environ, - {"QDRANT_URL": "http://localhost:6333", "COLLECTION_NAME": "test_collection"}, - ) - def test_minimal_config(self): + def test_minimal_config(self, monkeypatch): """Test loading minimal configuration from environment variables.""" + monkeypatch.setenv("QDRANT_URL", "http://localhost:6333") + monkeypatch.setenv("COLLECTION_NAME", "test_collection") + settings = QdrantSettings() assert settings.location == "http://localhost:6333" assert settings.collection_name == "test_collection" assert settings.api_key is None assert settings.local_path is None - @patch.dict( - os.environ, - { - "QDRANT_URL": "http://qdrant.example.com:6333", - "QDRANT_API_KEY": "test_api_key", - "COLLECTION_NAME": "my_memories", - "QDRANT_LOCAL_PATH": "/tmp/qdrant", - }, - ) - def test_full_config(self): + def test_full_config(self, monkeypatch): """Test loading full configuration from environment variables.""" + monkeypatch.setenv("QDRANT_URL", "http://qdrant.example.com:6333") + monkeypatch.setenv("QDRANT_API_KEY", "test_api_key") + monkeypatch.setenv("COLLECTION_NAME", "my_memories") + monkeypatch.setenv("QDRANT_SEARCH_LIMIT", "15") + monkeypatch.setenv("QDRANT_READ_ONLY", "1") + settings = QdrantSettings() assert settings.location == "http://qdrant.example.com:6333" assert settings.api_key == "test_api_key" assert settings.collection_name == "my_memories" - assert settings.local_path == "/tmp/qdrant" + assert settings.search_limit == 15 + assert settings.read_only is True + + def test_local_path_config(self, monkeypatch): + """Test loading local path configuration from environment variables.""" + monkeypatch.setenv("QDRANT_LOCAL_PATH", "/path/to/local/qdrant") + + settings = QdrantSettings() + assert settings.local_path == "/path/to/local/qdrant" + + def test_local_path_is_exclusive_with_url(self, monkeypatch): + """Test that local path cannot be set if Qdrant URL is provided.""" + monkeypatch.setenv("QDRANT_URL", "http://localhost:6333") + monkeypatch.setenv("QDRANT_LOCAL_PATH", "/path/to/local/qdrant") + + with pytest.raises(ValueError): + QdrantSettings() + + monkeypatch.delenv("QDRANT_URL", raising=False) + monkeypatch.setenv("QDRANT_API_KEY", "test_api_key") + with pytest.raises(ValueError): + QdrantSettings() class TestEmbeddingProviderSettings: @@ -55,12 +71,9 @@ def test_default_values(self): assert settings.provider_type == EmbeddingProviderType.FASTEMBED assert settings.model_name == "sentence-transformers/all-MiniLM-L6-v2" - @patch.dict( - os.environ, - {"EMBEDDING_MODEL": "custom_model"}, - ) - def test_custom_values(self): + def test_custom_values(self, monkeypatch): """Test loading custom values from environment variables.""" + monkeypatch.setenv("EMBEDDING_MODEL", "custom_model") settings = EmbeddingProviderSettings() assert settings.provider_type == EmbeddingProviderType.FASTEMBED assert settings.model_name == "custom_model" @@ -73,35 +86,24 @@ def test_default_values(self): assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION - @patch.dict( - os.environ, - {"TOOL_STORE_DESCRIPTION": "Custom store description"}, - ) - def test_custom_store_description(self): + def test_custom_store_description(self, monkeypatch): """Test loading custom store description from environment variable.""" + monkeypatch.setenv("TOOL_STORE_DESCRIPTION", "Custom store description") settings = ToolSettings() assert settings.tool_store_description == "Custom store description" assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION - @patch.dict( - os.environ, - {"TOOL_FIND_DESCRIPTION": "Custom find description"}, - ) - def test_custom_find_description(self): + def test_custom_find_description(self, monkeypatch): """Test loading custom find description from environment variable.""" + monkeypatch.setenv("TOOL_FIND_DESCRIPTION", "Custom find description") settings = ToolSettings() assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION assert settings.tool_find_description == "Custom find description" - @patch.dict( - os.environ, - { - "TOOL_STORE_DESCRIPTION": "Custom store description", - "TOOL_FIND_DESCRIPTION": "Custom find description", - }, - ) - def test_all_custom_values(self): + def test_all_custom_values(self, monkeypatch): """Test loading all custom values from environment variables.""" + monkeypatch.setenv("TOOL_STORE_DESCRIPTION", "Custom store description") + monkeypatch.setenv("TOOL_FIND_DESCRIPTION", "Custom find description") settings = ToolSettings() assert settings.tool_store_description == "Custom store description" assert settings.tool_find_description == "Custom find description"