Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for hybrid search in Azure AI vector store #2408

Merged
merged 3 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions docs/components/vectordbs/dbs/azure_ai_search.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@ config = {
}
```

## Using Hybrid Search

```python
config = {
"vector_store": {
"provider": "azure_ai_search",
"config": {
"service_name": "ai-search-test",
"api_key": "*****",
"collection_name": "mem0",
"embedding_model_dims": 1536,
"hybrid_search": True,
"vector_filter_mode": "postFilter"
}
}
}
```

## Configuration Parameters

| Parameter | Description | Default Value | Options |
Expand All @@ -60,6 +78,8 @@ config = {
| `embedding_model_dims` | Dimensions of the embedding model | `1536` | Any integer value |
| `compression_type` | Type of vector compression to use | `none` | `none`, `scalar`, `binary` |
| `use_float16` | Store vectors in half precision (Edm.Half) | `False` | `True`, `False` |
| `vector_filter_mode` | Whether filters run before or after the vector search | `preFilter` | `postFilter`, `preFilter` |
| `hybrid_search` | Use hybrid search | `False` | `True`, `False` |

## Notes on Configuration Options

Expand All @@ -68,6 +88,10 @@ config = {
- `scalar`: Scalar quantization with reasonable balance of speed and accuracy
- `binary`: Binary quantization for maximum compression with some accuracy trade-off

- **vector_filter_mode**:
- `preFilter`: Applies filters before vector search (faster)
- `postFilter`: Applies filters after vector search (may provide better relevance)

- **use_float16**: Using half precision (float16) reduces storage requirements but may slightly impact accuracy. Useful for very large vector collections.

- **Filterable Fields**: The implementation automatically extracts `user_id`, `run_id`, and `agent_id` fields from payloads for filtering.
27 changes: 16 additions & 11 deletions mem0/configs/vector_stores/azure_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,40 @@ class AzureAISearchConfig(BaseModel):
api_key: str = Field(None, description="API key for the Azure AI Search service")
embedding_model_dims: int = Field(None, description="Dimension of the embedding vector")
compression_type: Optional[str] = Field(
None,
description="Type of vector compression to use. Options: 'scalar', 'binary', or None"
None, description="Type of vector compression to use. Options: 'scalar', 'binary', or None"
)
use_float16: bool = Field(
False,
description="Whether to store vectors in half precision (Edm.Half) instead of full precision (Edm.Single)"
False,
description="Whether to store vectors in half precision (Edm.Half) instead of full precision (Edm.Single)",
)

hybrid_search: bool = Field(
False, description="Whether to use hybrid search. If True, vector_filter_mode must be 'preFilter'"
)
vector_filter_mode: Optional[str] = Field(
"preFilter", description="Mode for vector filtering. Options: 'preFilter', 'postFilter'"
)

@model_validator(mode="before")
@classmethod
def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
allowed_fields = set(cls.model_fields.keys())
input_fields = set(values.keys())
extra_fields = input_fields - allowed_fields

# Check for use_compression to provide a helpful error
if "use_compression" in extra_fields:
raise ValueError(
"The parameter 'use_compression' is no longer supported. "
"Please use 'compression_type=\"scalar\"' instead of 'use_compression=True' "
"or 'compression_type=None' instead of 'use_compression=False'."
)

if extra_fields:
raise ValueError(
f"Extra fields not allowed: {', '.join(extra_fields)}. "
f"Please input only the following fields: {', '.join(allowed_fields)}"
)

# Validate compression_type values
if "compression_type" in values and values["compression_type"] is not None:
valid_types = ["scalar", "binary"]
Expand All @@ -45,9 +50,9 @@ def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
f"Invalid compression_type: {values['compression_type']}. "
f"Must be one of: {', '.join(valid_types)}, or None"
)

return values

model_config = {
"arbitrary_types_allowed": True,
}
}
3 changes: 1 addition & 2 deletions mem0/configs/vector_stores/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ class ElasticsearchConfig(BaseModel):
use_ssl: bool = Field(True, description="Use SSL for connection")
auto_create_index: bool = Field(True, description="Automatically create index during initialization")
custom_search_query: Optional[Callable[[List[float], int, Optional[Dict]], Dict]] = Field(
None,
description="Custom search query function. Parameters: (query, limit, filters) -> Dict"
None, description="Custom search query function. Parameters: (query, limit, filters) -> Dict"
)

@model_validator(mode="before")
Expand Down
6 changes: 2 additions & 4 deletions mem0/configs/vector_stores/vertex_ai_vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@ class GoogleMatchingEngineConfig(BaseModel):
credentials_path: Optional[str] = Field(None, description="Path to service account credentials file")
vector_search_api_endpoint: Optional[str] = Field(None, description="Vector search API endpoint")

model_config = {
"extra": "forbid"
}
model_config = {"extra": "forbid"}

def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand All @@ -26,4 +24,4 @@ def __init__(self, **kwargs):
def model_post_init(self, _context) -> None:
"""Set collection_name to index_id if not provided"""
if self.collection_name is None:
self.collection_name = self.index_id
self.collection_name = self.index_id
14 changes: 9 additions & 5 deletions mem0/memory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,14 @@ def _process_config(config_dict: Dict[str, Any]) -> Dict[str, Any]:
if "vector_store" not in config_dict and "embedder" in config_dict:
config_dict["vector_store"] = {}
config_dict["vector_store"]["config"] = {}
config_dict["vector_store"]["config"]["embedding_model_dims"] = config_dict["embedder"]["config"]["embedding_dims"]
config_dict["vector_store"]["config"]["embedding_model_dims"] = config_dict["embedder"]["config"][
"embedding_dims"
]
try:
return config_dict
except ValidationError as e:
logger.error(f"Configuration validation error: {e}")
raise


def add(
self,
Expand Down Expand Up @@ -204,7 +205,8 @@ def _add_to_vector_store(self, messages, metadata, filters, infer):
messages_embeddings = self.embedding_model.embed(new_mem, "add")
new_message_embeddings[new_mem] = messages_embeddings
existing_memories = self.vector_store.search(
query=messages_embeddings,
query=new_mem,
vectors=messages_embeddings,
limit=5,
filters=filters,
)
Expand All @@ -222,7 +224,9 @@ def _add_to_vector_store(self, messages, metadata, filters, infer):
temp_uuid_mapping[str(idx)] = item["id"]
retrieved_old_memory[idx]["id"] = str(idx)

function_calling_prompt = get_update_memory_messages(retrieved_old_memory, new_retrieved_facts, self.custom_update_memory_prompt)
function_calling_prompt = get_update_memory_messages(
retrieved_old_memory, new_retrieved_facts, self.custom_update_memory_prompt
)

try:
new_memories_with_actions = self.llm.generate_response(
Expand Down Expand Up @@ -479,7 +483,7 @@ def search(self, query, user_id=None, agent_id=None, run_id=None, limit=100, fil

def _search_vector_store(self, query, filters, limit):
embeddings = self.embedding_model.embed(query, "search")
memories = self.vector_store.search(query=embeddings, limit=limit, filters=filters)
memories = self.vector_store.search(query=query, vectors=embeddings, limit=limit, filters=filters)

excluded_keys = {
"user_id",
Expand Down
71 changes: 36 additions & 35 deletions mem0/vector_stores/azure_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ def __init__(
collection_name,
api_key,
embedding_model_dims,
compression_type: Optional[str] = None,
compression_type: Optional[str] = None,
use_float16: bool = False,
hybrid_search: bool = False,
vector_filter_mode: Optional[str] = None,
):
"""
Initialize the Azure AI Search vector store.
Expand All @@ -60,13 +62,17 @@ def __init__(
Allowed values are None (no quantization), "scalar", or "binary".
use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
(Note: This flag is preserved from the initial implementation per feedback.)
hybrid_search (bool): Whether to use hybrid search. Default is False.
vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
"""
self.index_name = collection_name
self.collection_name = collection_name
self.embedding_model_dims = embedding_model_dims
# If compression_type is None, treat it as "none".
self.compression_type = (compression_type or "none").lower()
self.compression_type = (compression_type or "none").lower()
self.use_float16 = use_float16
self.hybrid_search = hybrid_search
self.vector_filter_mode = vector_filter_mode

self.search_client = SearchClient(
endpoint=f"https://{service_name}.search.windows.net",
Expand Down Expand Up @@ -113,8 +119,6 @@ def create_col(self):
)
]
# If no compression is desired, compression_configurations remains empty.


fields = [
SimpleField(name="id", type=SearchFieldDataType.String, key=True),
SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
Expand All @@ -123,19 +127,19 @@ def create_col(self):
SearchField(
name="vector",
type=vector_type,
searchable=True,
searchable=True,
vector_search_dimensions=self.embedding_model_dims,
vector_search_profile_name="my-vector-config",
),
SimpleField(name="payload", type=SearchFieldDataType.String, searchable=True),
SearchField(name="payload", type=SearchFieldDataType.String, searchable=True),
]

vector_search = VectorSearch(
profiles=[
VectorSearchProfile(
name="my-vector-config",
algorithm_configuration_name="my-algorithms-config",
compression_name=compression_name if self.compression_type != "none" else None
compression_name=compression_name if self.compression_type != "none" else None,
)
],
algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
Expand Down Expand Up @@ -164,8 +168,7 @@ def insert(self, vectors, payloads=None, ids=None):
"""
logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
documents = [
self._generate_document(vector, payload, id)
for id, vector, payload in zip(ids, vectors, payloads)
self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads)
]
response = self.search_client.upload_documents(documents)
for doc in response:
Expand All @@ -189,12 +192,13 @@ def _build_filter_expression(self, filters):
filter_expression = " and ".join(filter_conditions)
return filter_expression

def search(self, query, limit=5, filters=None):
def search(self, query, vectors, limit=5, filters=None):
"""
Search for similar vectors.

Args:
query (List[float]): Query vector.
query (str): Query.
vectors (List[float]): Query vector.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Dict, optional): Filters to apply to the search. Defaults to None.

Expand All @@ -205,23 +209,28 @@ def search(self, query, limit=5, filters=None):
if filters:
filter_expression = self._build_filter_expression(filters)

vector_query = VectorizedQuery(
vector=query, k_nearest_neighbors=limit, fields="vector"
)
search_results = self.search_client.search(
vector_queries=[vector_query],
filter=filter_expression,
top=limit
)
vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector")
if self.hybrid_search:
search_results = self.search_client.search(
search_text=query,
vector_queries=[vector_query],
filter=filter_expression,
top=limit,
vector_filter_mode=self.vector_filter_mode,
search_fields=["payload"],
)
else:
search_results = self.search_client.search(
vector_queries=[vector_query],
filter=filter_expression,
top=limit,
vector_filter_mode=self.vector_filter_mode,
)

results = []
for result in search_results:
payload = json.loads(result["payload"])
results.append(
OutputData(
id=result["id"], score=result["@search.score"], payload=payload
)
)
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
return results

def delete(self, vector_id):
Expand Down Expand Up @@ -275,9 +284,7 @@ def get(self, vector_id) -> OutputData:
result = self.search_client.get_document(key=vector_id)
except ResourceNotFoundError:
return None
return OutputData(
id=result["id"], score=None, payload=json.loads(result["payload"])
)
return OutputData(id=result["id"], score=None, payload=json.loads(result["payload"]))

def list_cols(self) -> List[str]:
"""
Expand Down Expand Up @@ -321,17 +328,11 @@ def list(self, filters=None, limit=100):
if filters:
filter_expression = self._build_filter_expression(filters)

search_results = self.search_client.search(
search_text="*", filter=filter_expression, top=limit
)
search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
results = []
for result in search_results:
payload = json.loads(result["payload"])
results.append(
OutputData(
id=result["id"], score=result["@search.score"], payload=payload
)
)
results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
return [results]

def __del__(self):
Expand Down
2 changes: 1 addition & 1 deletion mem0/vector_stores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def insert(self, vectors, payloads=None, ids=None):
pass

@abstractmethod
def search(self, query, limit=5, filters=None):
def search(self, query, vectors, limit=5, filters=None):
"""Search for similar vectors."""
pass

Expand Down
9 changes: 6 additions & 3 deletions mem0/vector_stores/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,19 +127,22 @@ def insert(
logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")
self.collection.add(ids=ids, embeddings=vectors, metadatas=payloads)

def search(self, query: List[list], limit: int = 5, filters: Optional[Dict] = None) -> List[OutputData]:
def search(
self, query: str, vectors: List[list], limit: int = 5, filters: Optional[Dict] = None
) -> List[OutputData]:
"""
Search for similar vectors.

Args:
query (List[list]): Query vector.
query (str): Query.
vectors (List[list]): List of vectors to search.
limit (int, optional): Number of results to return. Defaults to 5.
filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.

Returns:
List[OutputData]: Search results.
"""
results = self.collection.query(query_embeddings=query, where=filters, n_results=limit)
results = self.collection.query(query_embeddings=vectors, where=filters, n_results=limit)
final_results = self._parse_output(results)
return final_results

Expand Down
Loading
Loading