Skip to content

Commit 4da8be3

Browse files
committed
feat(cassandra/astradb): hybrid search support (langflow-ai#2396)
* cassandra/astradb: hybrid search support * fix * fix (cherry picked from commit 30c369f)
1 parent 9aa8799 commit 4da8be3

File tree

3 files changed

+186
-54
lines changed

3 files changed

+186
-54
lines changed

src/backend/base/langflow/base/vectorstores/model.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def build_base_retriever(self) -> Retriever: # type: ignore[type-var]
7979
"""
8080
vector_store = self.build_vector_store()
8181
if hasattr(vector_store, "as_retriever"):
82-
retriever = vector_store.as_retriever()
82+
retriever = vector_store.as_retriever(**self.get_retriever_kwargs())
8383
if self.status is None:
8484
self.status = "Retriever built successfully."
8585
return retriever
@@ -106,3 +106,9 @@ def search_documents(self) -> List[Data]:
106106
)
107107
self.status = search_results
108108
return search_results
109+
110+
def get_retriever_kwargs(self):
    """
    Return the keyword arguments used when building the retriever.

    The base implementation supplies no arguments: it returns a new, empty
    dict on every call. Implementations can override this method to provide
    custom retriever kwargs.
    """
    retriever_kwargs = {}
    return retriever_kwargs

src/backend/base/langflow/components/vectorstores/AstraDB.py

+65-23
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from loguru import logger
22

3+
from langchain_core.vectorstores import VectorStore
34
from langflow.base.vectorstores.model import LCVectorStoreComponent
5+
from langflow.helpers import docs_to_data
6+
from langflow.inputs import FloatInput, DictInput
47
from langflow.io import (
58
BoolInput,
69
DataInput,
@@ -20,6 +23,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
2023
documentation: str = "https://python.langchain.com/docs/integrations/vectorstores/astradb"
2124
icon: str = "AstraDB"
2225

26+
_cached_vectorstore: VectorStore = None
27+
2328
inputs = [
2429
StrInput(
2530
name="collection_name",
@@ -124,23 +129,40 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
124129
info="Optional dictionary defining the indexing policy for the collection.",
125130
advanced=True,
126131
),
132+
IntInput(
133+
name="number_of_results",
134+
display_name="Number of Results",
135+
info="Number of results to return.",
136+
advanced=True,
137+
value=4,
138+
),
127139
DropdownInput(
128140
name="search_type",
129141
display_name="Search Type",
130-
options=["Similarity", "MMR"],
142+
info="Search type to use",
143+
options=["Similarity", "Similarity with score threshold", "MMR (Max Marginal Relevance)"],
131144
value="Similarity",
132145
advanced=True,
133146
),
134-
IntInput(
135-
name="number_of_results",
136-
display_name="Number of Results",
137-
info="Number of results to return.",
147+
FloatInput(
148+
name="search_score_threshold",
149+
display_name="Search Score Threshold",
150+
info="Minimum similarity score threshold for search results. (when using 'Similarity with score threshold')",
151+
value=0,
138152
advanced=True,
139-
value=4,
153+
),
154+
DictInput(
155+
name="search_filter",
156+
display_name="Search Metadata Filter",
157+
info="Optional dictionary of filters to apply to the search query.",
158+
advanced=True,
159+
is_list=True,
140160
),
141161
]
142162

143163
def _build_vector_store_no_ingest(self):
164+
if self._cached_vectorstore:
165+
return self._cached_vectorstore
144166
try:
145167
from langchain_astradb import AstraDBVectorStore
146168
from langchain_astradb.utils.astradb import SetupMode
@@ -199,13 +221,13 @@ def _build_vector_store_no_ingest(self):
199221
except Exception as e:
200222
raise ValueError(f"Error initializing AstraDBVectorStore: {str(e)}") from e
201223

224+
self._cached_vectorstore = vector_store
225+
202226
return vector_store
203227

204228
def build_vector_store(self):
205229
vector_store = self._build_vector_store_no_ingest()
206-
if hasattr(self, "ingest_data") and self.ingest_data:
207-
logger.debug("Ingesting data into the Vector Store.")
208-
self._add_documents_to_vector_store(vector_store)
230+
self._add_documents_to_vector_store(vector_store)
209231
return vector_store
210232

211233
def _add_documents_to_vector_store(self, vector_store):
@@ -216,7 +238,7 @@ def _add_documents_to_vector_store(self, vector_store):
216238
else:
217239
raise ValueError("Vector Store Inputs must be Data objects.")
218240

219-
if documents and self.embedding is not None:
241+
if documents:
220242
logger.debug(f"Adding {len(documents)} documents to the Vector Store.")
221243
try:
222244
vector_store.add_documents(documents)
@@ -225,36 +247,56 @@ def _add_documents_to_vector_store(self, vector_store):
225247
else:
226248
logger.debug("No documents to add to the Vector Store.")
227249

250+
def _map_search_type(self):
251+
if self.search_type == "Similarity with score threshold":
252+
return "similarity_score_threshold"
253+
elif self.search_type == "MMR (Max Marginal Relevance)":
254+
return "mmr"
255+
else:
256+
return "similarity"
257+
228258
def search_documents(self) -> list[Data]:
229259
vector_store = self._build_vector_store_no_ingest()
260+
self._add_documents_to_vector_store(vector_store)
230261

231262
logger.debug(f"Search input: {self.search_input}")
232263
logger.debug(f"Search type: {self.search_type}")
233264
logger.debug(f"Number of results: {self.number_of_results}")
234265

235266
if self.search_input and isinstance(self.search_input, str) and self.search_input.strip():
236267
try:
237-
if self.search_type == "Similarity":
238-
docs = vector_store.similarity_search(
239-
query=self.search_input,
240-
k=self.number_of_results,
241-
)
242-
elif self.search_type == "MMR":
243-
docs = vector_store.max_marginal_relevance_search(
244-
query=self.search_input,
245-
k=self.number_of_results,
246-
)
247-
else:
248-
raise ValueError(f"Invalid search type: {self.search_type}")
268+
search_type = self._map_search_type()
269+
search_args = self._build_search_args()
270+
271+
docs = vector_store.search(query=self.search_input, search_type=search_type, **search_args)
249272
except Exception as e:
250273
raise ValueError(f"Error performing search in AstraDBVectorStore: {str(e)}") from e
251274

252275
logger.debug(f"Retrieved documents: {len(docs)}")
253276

254-
data = [Data.from_document(doc) for doc in docs]
277+
data = docs_to_data(docs)
255278
logger.debug(f"Converted documents to data: {len(data)}")
256279
self.status = data
257280
return data
258281
else:
259282
logger.debug("No search input provided. Skipping search.")
260283
return []
284+
285+
def _build_search_args(self):
286+
args = {
287+
"k": self.number_of_results,
288+
"score_threshold": self.search_score_threshold,
289+
}
290+
291+
if self.search_filter:
292+
clean_filter = {k: v for k, v in self.search_filter.items() if k and v}
293+
if len(clean_filter) > 0:
294+
args["filter"] = clean_filter
295+
return args
296+
297+
def get_retriever_kwargs(self):
    """
    Build the retriever kwargs from this component's search settings.

    Returns:
        A dict with "search_type" (the result of ``_map_search_type``) and
        "search_kwargs" (the result of ``_build_search_args``).
    """
    kwargs_for_search = self._build_search_args()
    retriever_kwargs = {"search_type": self._map_search_type()}
    retriever_kwargs["search_kwargs"] = kwargs_for_search
    return retriever_kwargs

0 commit comments

Comments
 (0)