Skip to content

new: update type hints #64

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/mcp_server_qdrant/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from abc import ABC, abstractmethod
from typing import List


class EmbeddingProvider(ABC):
"""Abstract base class for embedding providers."""

@abstractmethod
async def embed_documents(self, documents: List[str]) -> List[List[float]]:
async def embed_documents(self, documents: list[str]) -> list[list[float]]:
"""Embed a list of documents into vectors."""
pass

@abstractmethod
async def embed_query(self, query: str) -> List[float]:
async def embed_query(self, query: str) -> list[float]:
"""Embed a query into a vector."""
pass

Expand Down
5 changes: 2 additions & 3 deletions src/mcp_server_qdrant/embeddings/fastembed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
from typing import List

from fastembed import TextEmbedding
from fastembed.common.model_description import DenseModelDescription
Expand All @@ -17,7 +16,7 @@ def __init__(self, model_name: str):
self.model_name = model_name
self.embedding_model = TextEmbedding(model_name)

async def embed_documents(self, documents: List[str]) -> List[List[float]]:
async def embed_documents(self, documents: list[str]) -> list[list[float]]:
"""Embed a list of documents into vectors."""
# Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop()
Expand All @@ -26,7 +25,7 @@ async def embed_documents(self, documents: List[str]) -> List[List[float]]:
)
return [embedding.tolist() for embedding in embeddings]

async def embed_query(self, query: str) -> List[float]:
async def embed_query(self, query: str) -> list[float]:
"""Embed a query into a vector."""
# Run in a thread pool since FastEmbed is synchronous
loop = asyncio.get_event_loop()
Expand Down
13 changes: 5 additions & 8 deletions src/mcp_server_qdrant/mcp_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import logging
from typing import Annotated, Any, List, Optional
from typing import Annotated, Any

from fastmcp import Context, FastMCP
from pydantic import Field
Expand Down Expand Up @@ -76,7 +76,7 @@ async def store(
# If we set it to be optional, some of the MCP clients, like Cursor, cannot
# handle the optional parameter correctly.
metadata: Annotated[
Optional[Metadata],
Metadata | None,
Field(
description="Extra metadata stored along with memorised information. Any json is accepted."
),
Expand Down Expand Up @@ -106,14 +106,15 @@ async def find(
collection_name: Annotated[
str, Field(description="The collection to search in")
],
query_filter: Optional[ArbitraryFilter] = None,
) -> List[str]:
query_filter: ArbitraryFilter | None = None,
) -> list[str]:
"""
Find memories in Qdrant.
:param ctx: The context for the request.
:param query: The query to use for the search.
:param collection_name: The name of the collection to search in, optional. If not provided,
the default collection is used.
:param query_filter: The filter to apply to the query.
:return: A list of entries found.
"""

Expand All @@ -123,10 +124,6 @@ async def find(
query_filter = models.Filter(**query_filter) if query_filter else None

await ctx.debug(f"Finding results for query {query}")
if collection_name:
await ctx.debug(
f"Overriding the collection name with {collection_name}"
)

entries = await self.qdrant_connector.search(
query,
Expand Down
27 changes: 14 additions & 13 deletions src/mcp_server_qdrant/qdrant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import uuid
from typing import Any, Dict, Optional
from typing import Any

from pydantic import BaseModel
from qdrant_client import AsyncQdrantClient, models
Expand All @@ -10,9 +10,8 @@

logger = logging.getLogger(__name__)

Metadata = Dict[str, Any]

ArbitraryFilter = Dict[str, Any]
Metadata = dict[str, Any]
ArbitraryFilter = dict[str, Any]


class Entry(BaseModel):
Expand All @@ -21,7 +20,7 @@ class Entry(BaseModel):
"""

content: str
metadata: Optional[Metadata] = None
metadata: Metadata | None = None


class QdrantConnector:
Expand All @@ -37,12 +36,12 @@ class QdrantConnector:

def __init__(
self,
qdrant_url: Optional[str],
qdrant_api_key: Optional[str],
collection_name: Optional[str],
qdrant_url: str | None,
qdrant_api_key: str | None,
collection_name: str | None,
embedding_provider: EmbeddingProvider,
qdrant_local_path: Optional[str] = None,
field_indexes: Optional[dict[str, models.PayloadSchemaType]] = None,
qdrant_local_path: str | None = None,
field_indexes: dict[str, models.PayloadSchemaType] | None = None,
):
self._qdrant_url = qdrant_url.rstrip("/") if qdrant_url else None
self._qdrant_api_key = qdrant_api_key
Expand All @@ -61,7 +60,7 @@ async def get_collection_names(self) -> list[str]:
response = await self._client.get_collections()
return [collection.name for collection in response.collections]

async def store(self, entry: Entry, *, collection_name: Optional[str] = None):
async def store(self, entry: Entry, *, collection_name: str | None = None):
"""
Store some information in the Qdrant collection, along with the specified metadata.
:param entry: The entry to store in the Qdrant collection.
Expand Down Expand Up @@ -95,16 +94,18 @@ async def search(
self,
query: str,
*,
collection_name: Optional[str] = None,
collection_name: str | None = None,
limit: int = 10,
query_filter: Optional[models.Filter] = None,
query_filter: models.Filter | None = None,
) -> list[Entry]:
"""
Find points in the Qdrant collection. If there are no entries found, an empty list is returned.
:param query: The query to use for the search.
:param collection_name: The name of the collection to search in, optional. If not provided,
the default collection is used.
:param limit: The maximum number of entries to return.
:param query_filter: The filter to apply to the query, if any.

:return: A list of entries found.
"""
collection_name = collection_name or self._default_collection_name
Expand Down
27 changes: 17 additions & 10 deletions src/mcp_server_qdrant/settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Literal, Optional
from typing import Literal

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from pydantic_settings import BaseSettings

from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
Expand Down Expand Up @@ -56,7 +56,7 @@ class FilterableField(BaseModel):
field_type: Literal["keyword", "integer", "float", "boolean"] = Field(
description="The type of the field"
)
condition: Optional[Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"]] = (
condition: Literal["==", "!=", ">", ">=", "<", "<=", "any", "except"] | None = (
Field(
default=None,
description=(
Expand All @@ -76,18 +76,16 @@ class QdrantSettings(BaseSettings):
Configuration for the Qdrant connector.
"""

location: Optional[str] = Field(default=None, validation_alias="QDRANT_URL")
api_key: Optional[str] = Field(default=None, validation_alias="QDRANT_API_KEY")
collection_name: Optional[str] = Field(
location: str | None = Field(default=None, validation_alias="QDRANT_URL")
api_key: str | None = Field(default=None, validation_alias="QDRANT_API_KEY")
collection_name: str | None = Field(
default=None, validation_alias="COLLECTION_NAME"
)
local_path: Optional[str] = Field(
default=None, validation_alias="QDRANT_LOCAL_PATH"
)
local_path: str | None = Field(default=None, validation_alias="QDRANT_LOCAL_PATH")
search_limit: int = Field(default=10, validation_alias="QDRANT_SEARCH_LIMIT")
read_only: bool = Field(default=False, validation_alias="QDRANT_READ_ONLY")

filterable_fields: Optional[list[FilterableField]] = Field(default=None)
filterable_fields: list[FilterableField] | None = Field(default=None)

allow_arbitrary_filter: bool = Field(
default=False, validation_alias="QDRANT_ALLOW_ARBITRARY_FILTER"
Expand All @@ -106,3 +104,12 @@ def filterable_fields_dict_with_conditions(self) -> dict[str, FilterableField]:
for field in self.filterable_fields
if field.condition is not None
}

@model_validator(mode="after")
def check_local_path_conflict(self) -> "QdrantSettings":
if self.local_path:
if self.location is not None or self.api_key is not None:
raise ValueError(
"If 'local_path' is set, 'location' and 'api_key' must be None."
)
return self
84 changes: 43 additions & 41 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from unittest.mock import patch
import pytest

from mcp_server_qdrant.embeddings.types import EmbeddingProviderType
from mcp_server_qdrant.settings import (
Expand All @@ -18,34 +17,51 @@ def test_default_values(self):
# Should not raise error because there are no required fields
QdrantSettings()

@patch.dict(
os.environ,
{"QDRANT_URL": "http://localhost:6333", "COLLECTION_NAME": "test_collection"},
)
def test_minimal_config(self):
def test_minimal_config(self, monkeypatch):
"""Test loading minimal configuration from environment variables."""
monkeypatch.setenv("QDRANT_URL", "http://localhost:6333")
monkeypatch.setenv("COLLECTION_NAME", "test_collection")

settings = QdrantSettings()
assert settings.location == "http://localhost:6333"
assert settings.collection_name == "test_collection"
assert settings.api_key is None
assert settings.local_path is None

@patch.dict(
os.environ,
{
"QDRANT_URL": "http://qdrant.example.com:6333",
"QDRANT_API_KEY": "test_api_key",
"COLLECTION_NAME": "my_memories",
"QDRANT_LOCAL_PATH": "/tmp/qdrant",
},
)
def test_full_config(self):
def test_full_config(self, monkeypatch):
"""Test loading full configuration from environment variables."""
monkeypatch.setenv("QDRANT_URL", "http://qdrant.example.com:6333")
monkeypatch.setenv("QDRANT_API_KEY", "test_api_key")
monkeypatch.setenv("COLLECTION_NAME", "my_memories")
monkeypatch.setenv("QDRANT_SEARCH_LIMIT", "15")
monkeypatch.setenv("QDRANT_READ_ONLY", "1")

settings = QdrantSettings()
assert settings.location == "http://qdrant.example.com:6333"
assert settings.api_key == "test_api_key"
assert settings.collection_name == "my_memories"
assert settings.local_path == "/tmp/qdrant"
assert settings.search_limit == 15
assert settings.read_only is True

def test_local_path_config(self, monkeypatch):
"""Test loading local path configuration from environment variables."""
monkeypatch.setenv("QDRANT_LOCAL_PATH", "/path/to/local/qdrant")

settings = QdrantSettings()
assert settings.local_path == "/path/to/local/qdrant"

def test_local_path_is_exclusive_with_url(self, monkeypatch):
"""Test that local path cannot be set if Qdrant URL is provided."""
monkeypatch.setenv("QDRANT_URL", "http://localhost:6333")
monkeypatch.setenv("QDRANT_LOCAL_PATH", "/path/to/local/qdrant")

with pytest.raises(ValueError):
QdrantSettings()

monkeypatch.delenv("QDRANT_URL", raising=False)
monkeypatch.setenv("QDRANT_API_KEY", "test_api_key")
with pytest.raises(ValueError):
QdrantSettings()


class TestEmbeddingProviderSettings:
Expand All @@ -55,12 +71,9 @@ def test_default_values(self):
assert settings.provider_type == EmbeddingProviderType.FASTEMBED
assert settings.model_name == "sentence-transformers/all-MiniLM-L6-v2"

@patch.dict(
os.environ,
{"EMBEDDING_MODEL": "custom_model"},
)
def test_custom_values(self):
def test_custom_values(self, monkeypatch):
"""Test loading custom values from environment variables."""
monkeypatch.setenv("EMBEDDING_MODEL", "custom_model")
settings = EmbeddingProviderSettings()
assert settings.provider_type == EmbeddingProviderType.FASTEMBED
assert settings.model_name == "custom_model"
Expand All @@ -73,35 +86,24 @@ def test_default_values(self):
assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION
assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION

@patch.dict(
os.environ,
{"TOOL_STORE_DESCRIPTION": "Custom store description"},
)
def test_custom_store_description(self):
def test_custom_store_description(self, monkeypatch):
"""Test loading custom store description from environment variable."""
monkeypatch.setenv("TOOL_STORE_DESCRIPTION", "Custom store description")
settings = ToolSettings()
assert settings.tool_store_description == "Custom store description"
assert settings.tool_find_description == DEFAULT_TOOL_FIND_DESCRIPTION

@patch.dict(
os.environ,
{"TOOL_FIND_DESCRIPTION": "Custom find description"},
)
def test_custom_find_description(self):
def test_custom_find_description(self, monkeypatch):
"""Test loading custom find description from environment variable."""
monkeypatch.setenv("TOOL_FIND_DESCRIPTION", "Custom find description")
settings = ToolSettings()
assert settings.tool_store_description == DEFAULT_TOOL_STORE_DESCRIPTION
assert settings.tool_find_description == "Custom find description"

@patch.dict(
os.environ,
{
"TOOL_STORE_DESCRIPTION": "Custom store description",
"TOOL_FIND_DESCRIPTION": "Custom find description",
},
)
def test_all_custom_values(self):
def test_all_custom_values(self, monkeypatch):
"""Test loading all custom values from environment variables."""
monkeypatch.setenv("TOOL_STORE_DESCRIPTION", "Custom store description")
monkeypatch.setenv("TOOL_FIND_DESCRIPTION", "Custom find description")
settings = ToolSettings()
assert settings.tool_store_description == "Custom store description"
assert settings.tool_find_description == "Custom find description"