Skip to content

Commit 67e08d0

Browse files
authored
Enable kwargs in SearchIndex and Embedding Retriever (#1185)
* Enable kwargs for semantic ranking
1 parent e21ce0c commit 67e08d0

File tree

7 files changed

+78
-49
lines changed

7 files changed

+78
-49
lines changed

integrations/azure_ai_search/example/document_store.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from haystack import Document
2-
from haystack.document_stores.types import DuplicatePolicy
32

43
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore
54

@@ -30,7 +29,7 @@
3029
meta={"version": 2.0, "label": "chapter_three"},
3130
),
3231
]
33-
document_store.write_documents(documents, policy=DuplicatePolicy.SKIP)
32+
document_store.write_documents(documents)
3433

3534
filters = {
3635
"operator": "AND",

integrations/azure_ai_search/example/embedding_retrieval.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from haystack import Document, Pipeline
22
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
33
from haystack.components.writers import DocumentWriter
4-
from haystack.document_stores.types import DuplicatePolicy
54

65
from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever
76
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore
@@ -38,9 +37,7 @@
3837
# Indexing Pipeline
3938
indexing_pipeline = Pipeline()
4039
indexing_pipeline.add_component(instance=document_embedder, name="doc_embedder")
41-
indexing_pipeline.add_component(
42-
instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="doc_writer"
43-
)
40+
indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="doc_writer")
4441
indexing_pipeline.connect("doc_embedder", "doc_writer")
4542

4643
indexing_pipeline.run({"doc_embedder": {"documents": documents}})

integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py

+26-15
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from haystack.document_stores.types import FilterPolicy
66
from haystack.document_stores.types.filter_policy import apply_filter_policy
77

8-
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters
8+
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters
99

1010
logger = logging.getLogger(__name__)
1111

@@ -25,16 +25,23 @@ def __init__(
2525
filters: Optional[Dict[str, Any]] = None,
2626
top_k: int = 10,
2727
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
28+
**kwargs,
2829
):
2930
"""
3031
Create the AzureAISearchEmbeddingRetriever component.
3132
3233
:param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever.
3334
:param filters: Filters applied when fetching documents from the Document Store.
34-
Filters are applied during the approximate kNN search to ensure the Retriever returns
35-
`top_k` matching documents.
3635
:param top_k: Maximum number of documents to return.
37-
:filter_policy: Policy to determine how filters are applied. Possible options:
36+
:param filter_policy: Policy to determine how filters are applied.
37+
:param kwargs: Additional keyword arguments to pass to the Azure AI search endpoint.
38+
Some of the supported parameters:
39+
- `query_type`: A string indicating the type of query to perform. Possible values are
40+
'simple', 'full' and 'semantic'.
41+
- `semantic_configuration_name`: The name of semantic configuration to be used when
42+
processing semantic queries.
43+
For more information on parameters, see the
44+
[official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/).
3845
3946
"""
4047
self._filters = filters or {}
@@ -43,6 +50,7 @@ def __init__(
4350
self._filter_policy = (
4451
filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy)
4552
)
53+
self._kwargs = kwargs
4654

4755
if not isinstance(document_store, AzureAISearchDocumentStore):
4856
message = "document_store must be an instance of AzureAISearchDocumentStore"
@@ -61,6 +69,7 @@ def to_dict(self) -> Dict[str, Any]:
6169
top_k=self._top_k,
6270
document_store=self._document_store.to_dict(),
6371
filter_policy=self._filter_policy.value,
72+
**self._kwargs,
6473
)
6574

6675
@classmethod
@@ -88,29 +97,31 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever":
8897
def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
8998
"""Retrieve documents from the AzureAISearchDocumentStore.
9099
91-
:param query_embedding: floats representing the query embedding
100+
:param query_embedding: A list of floats representing the query embedding.
92101
:param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on
93-
the `filter_policy` chosen at retriever initialization. See init method docstring for more
94-
details.
95-
:param top_k: the maximum number of documents to retrieve.
96-
:returns: a dictionary with the following keys:
97-
- `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
102+
the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more
103+
details.
104+
:param top_k: The maximum number of documents to retrieve.
105+
:returns: Dictionary with the following keys:
106+
- `documents`: A list of documents retrieved from the AzureAISearchDocumentStore.
98107
"""
99108

100109
top_k = top_k or self._top_k
101110
if filters is not None:
102111
applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
103-
normalized_filters = normalize_filters(applied_filters)
112+
normalized_filters = _normalize_filters(applied_filters)
104113
else:
105114
normalized_filters = ""
106115

107116
try:
108117
docs = self._document_store._embedding_retrieval(
109-
query_embedding=query_embedding,
110-
filters=normalized_filters,
111-
top_k=top_k,
118+
query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs
112119
)
113120
except Exception as e:
114-
raise e
121+
msg = (
122+
"An error occurred during the embedding retrieval process from the AzureAISearchDocumentStore. "
123+
"Ensure that the query embedding is valid and the document store is correctly configured."
124+
)
125+
raise RuntimeError(msg) from e
115126

116127
return {"documents": docs}

integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44
from .document_store import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore
5-
from .filters import normalize_filters
5+
from .filters import _normalize_filters
66

7-
__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "normalize_filters"]
7+
__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "_normalize_filters"]

integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py

+29-21
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from haystack.utils import Secret, deserialize_secrets_inplace
3232

3333
from .errors import AzureAISearchDocumentStoreConfigError
34-
from .filters import normalize_filters
34+
from .filters import _normalize_filters
3535

3636
type_mapping = {
3737
str: "Edm.String",
@@ -70,7 +70,7 @@ def __init__(
7070
embedding_dimension: int = 768,
7171
metadata_fields: Optional[Dict[str, type]] = None,
7272
vector_search_configuration: VectorSearch = None,
73-
**kwargs,
73+
**index_creation_kwargs,
7474
):
7575
"""
7676
A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/)
@@ -87,19 +87,20 @@ def __init__(
8787
:param vector_search_configuration: Configuration option related to vector search.
8888
Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches.
8989
90-
:param kwargs: Optional keyword parameters for Azure AI Search.
91-
Some of the supported parameters:
92-
- `api_version`: The Search API version to use for requests.
93-
- `audience`: sets the Audience to use for authentication with Azure Active Directory (AAD).
94-
The audience is not considered when using a shared key. If audience is not provided,
95-
the public cloud audience will be assumed.
90+
:param index_creation_kwargs: Optional keyword parameters to be passed to `SearchIndex` class
91+
during index creation. Some of the supported parameters:
92+
- `semantic_search`: Defines semantic configuration of the search index. This parameter is needed
93+
to enable semantic search capabilities in the index.
94+
- `similarity`: The type of similarity algorithm to be used when scoring and ranking the documents
95+
matching a search query. The similarity algorithm can only be defined at index creation time and
96+
cannot be modified on existing indexes.
9697
97-
For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/)
98+
For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/).
9899
"""
99100

100101
azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None
101102
if not azure_endpoint:
102-
msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT."
103+
msg = "Please provide an Azure endpoint or set the environment variable AZURE_SEARCH_SERVICE_ENDPOINT."
103104
raise ValueError(msg)
104105

105106
api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None
@@ -114,7 +115,7 @@ def __init__(
114115
self._dummy_vector = [-10.0] * self._embedding_dimension
115116
self._metadata_fields = metadata_fields
116117
self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH
117-
self._kwargs = kwargs
118+
self._index_creation_kwargs = index_creation_kwargs
118119

119120
@property
120121
def client(self) -> SearchClient:
@@ -128,7 +129,10 @@ def client(self) -> SearchClient:
128129
credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential()
129130
try:
130131
if not self._index_client:
131-
self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs)
132+
self._index_client = SearchIndexClient(
133+
resolved_endpoint,
134+
credential,
135+
)
132136
if not self._index_exists(self._index_name):
133137
# Create a new index if it does not exist
134138
logger.debug(
@@ -151,7 +155,7 @@ def client(self) -> SearchClient:
151155

152156
return self._client
153157

154-
def _create_index(self, index_name: str, **kwargs) -> None:
158+
def _create_index(self, index_name: str) -> None:
155159
"""
156160
Creates a new search index.
157161
:param index_name: Name of the index to create. If None, the index name from the constructor is used.
@@ -177,7 +181,10 @@ def _create_index(self, index_name: str, **kwargs) -> None:
177181
if self._metadata_fields:
178182
default_fields.extend(self._create_metadata_index_fields(self._metadata_fields))
179183
index = SearchIndex(
180-
name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **kwargs
184+
name=index_name,
185+
fields=default_fields,
186+
vector_search=self._vector_search_configuration,
187+
**self._index_creation_kwargs,
181188
)
182189
if self._index_client:
183190
self._index_client.create_index(index)
@@ -194,13 +201,13 @@ def to_dict(self) -> Dict[str, Any]:
194201
"""
195202
return default_to_dict(
196203
self,
197-
azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint is not None else None,
198-
api_key=self._api_key.to_dict() if self._api_key is not None else None,
204+
azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint else None,
205+
api_key=self._api_key.to_dict() if self._api_key else None,
199206
index_name=self._index_name,
200207
embedding_dimension=self._embedding_dimension,
201208
metadata_fields=self._metadata_fields,
202209
vector_search_configuration=self._vector_search_configuration.as_dict(),
203-
**self._kwargs,
210+
**self._index_creation_kwargs,
204211
)
205212

206213
@classmethod
@@ -298,7 +305,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
298305
:returns: A list of Documents that match the given filters.
299306
"""
300307
if filters:
301-
normalized_filters = normalize_filters(filters)
308+
normalized_filters = _normalize_filters(filters)
302309
result = self.client.search(filter=normalized_filters)
303310
return self._convert_search_result_to_documents(result)
304311
else:
@@ -409,8 +416,8 @@ def _embedding_retrieval(
409416
query_embedding: List[float],
410417
*,
411418
top_k: int = 10,
412-
fields: Optional[List[str]] = None,
413419
filters: Optional[Dict[str, Any]] = None,
420+
**kwargs,
414421
) -> List[Document]:
415422
"""
416423
Retrieves documents that are most similar to the query embedding using a vector similarity metric.
@@ -422,9 +429,10 @@ def _embedding_retrieval(
422429
`AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it.
423430
424431
:param query_embedding: Embedding of the query.
432+
:param top_k: Maximum number of Documents to return, defaults to 10.
425433
:param filters: Filters applied to the retrieved Documents. Defaults to None.
426434
Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned.
427-
:param top_k: Maximum number of Documents to return, defaults to 10
435+
:param kwargs: Optional keyword arguments to pass to the Azure AI search endpoint.
428436
429437
:raises ValueError: If `query_embedding` is an empty list
430438
:returns: List of Document that are most similar to `query_embedding`
@@ -435,6 +443,6 @@ def _embedding_retrieval(
435443
raise ValueError(msg)
436444

437445
vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding")
438-
result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters)
446+
result = self.client.search(vector_queries=[vector_query], filter=filters, **kwargs)
439447
azure_docs = list(result)
440448
return self._convert_search_result_to_documents(azure_docs)

integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
LOGICAL_OPERATORS = {"AND": "and", "OR": "or", "NOT": "not"}
88

99

10-
def normalize_filters(filters: Dict[str, Any]) -> str:
10+
def _normalize_filters(filters: Dict[str, Any]) -> str:
1111
"""
1212
Converts Haystack filters into Azure AI Search compatible filters.
1313
"""

integrations/azure_ai_search/tests/conftest.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
from azure.core.credentials import AzureKeyCredential
77
from azure.core.exceptions import ResourceNotFoundError
88
from azure.search.documents.indexes import SearchIndexClient
9+
from haystack import logging
910
from haystack.document_stores.types import DuplicatePolicy
1011

1112
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore
1213

1314
# This is the approximate time in seconds it takes for the documents to be available in the Azure Search index
14-
SLEEP_TIME_IN_SECONDS = 5
15+
SLEEP_TIME_IN_SECONDS = 10
16+
MAX_WAIT_TIME_FOR_INDEX_DELETION = 5
1517

1618

1719
@pytest.fixture()
@@ -46,23 +48,35 @@ def document_store(request):
4648

4749
# Override some methods to wait for the documents to be available
4850
original_write_documents = store.write_documents
51+
original_delete_documents = store.delete_documents
4952

5053
def write_documents_and_wait(documents, policy=DuplicatePolicy.OVERWRITE):
5154
written_docs = original_write_documents(documents, policy)
5255
time.sleep(SLEEP_TIME_IN_SECONDS)
5356
return written_docs
5457

55-
original_delete_documents = store.delete_documents
56-
5758
def delete_documents_and_wait(filters):
5859
original_delete_documents(filters)
5960
time.sleep(SLEEP_TIME_IN_SECONDS)
6061

62+
# Helper function to wait for the index to be deleted, needed to cover latency
63+
def wait_for_index_deletion(client, index_name):
64+
start_time = time.time()
65+
while time.time() - start_time < MAX_WAIT_TIME_FOR_INDEX_DELETION:
66+
if index_name not in client.list_index_names():
67+
return True
68+
time.sleep(1)
69+
return False
70+
6171
store.write_documents = write_documents_and_wait
6272
store.delete_documents = delete_documents_and_wait
6373

6474
yield store
6575
try:
6676
client.delete_index(index_name)
77+
if not wait_for_index_deletion(client, index_name):
78+
logging.error(f"Index {index_name} was not properly deleted.")
6779
except ResourceNotFoundError:
68-
pass
80+
logging.info(f"Index {index_name} was already deleted or not found.")
81+
except Exception as e:
82+
logging.error(f"Unexpected error when deleting index {index_name}: {e}")

0 commit comments

Comments
 (0)