1
1
from loguru import logger
2
2
3
+ from langchain_core .vectorstores import VectorStore
3
4
from langflow .base .vectorstores .model import LCVectorStoreComponent
5
+ from langflow .helpers import docs_to_data
6
+ from langflow .inputs import FloatInput , DictInput
4
7
from langflow .io import (
5
8
BoolInput ,
6
9
DataInput ,
@@ -20,6 +23,8 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
20
23
documentation : str = "https://python.langchain.com/docs/integrations/vectorstores/astradb"
21
24
icon : str = "AstraDB"
22
25
26
+ _cached_vectorstore : VectorStore = None
27
+
23
28
inputs = [
24
29
StrInput (
25
30
name = "collection_name" ,
@@ -124,23 +129,40 @@ class AstraVectorStoreComponent(LCVectorStoreComponent):
124
129
info = "Optional dictionary defining the indexing policy for the collection." ,
125
130
advanced = True ,
126
131
),
132
+ IntInput (
133
+ name = "number_of_results" ,
134
+ display_name = "Number of Results" ,
135
+ info = "Number of results to return." ,
136
+ advanced = True ,
137
+ value = 4 ,
138
+ ),
127
139
DropdownInput (
128
140
name = "search_type" ,
129
141
display_name = "Search Type" ,
130
- options = ["Similarity" , "MMR" ],
142
+ info = "Search type to use" ,
143
+ options = ["Similarity" , "Similarity with score threshold" , "MMR (Max Marginal Relevance)" ],
131
144
value = "Similarity" ,
132
145
advanced = True ,
133
146
),
134
- IntInput (
135
- name = "number_of_results" ,
136
- display_name = "Number of Results" ,
137
- info = "Number of results to return." ,
147
+ FloatInput (
148
+ name = "search_score_threshold" ,
149
+ display_name = "Search Score Threshold" ,
150
+ info = "Minimum similarity score threshold for search results. (when using 'Similarity with score threshold')" ,
151
+ value = 0 ,
138
152
advanced = True ,
139
- value = 4 ,
153
+ ),
154
+ DictInput (
155
+ name = "search_filter" ,
156
+ display_name = "Search Metadata Filter" ,
157
+ info = "Optional dictionary of filters to apply to the search query." ,
158
+ advanced = True ,
159
+ is_list = True ,
140
160
),
141
161
]
142
162
143
163
def _build_vector_store_no_ingest (self ):
164
+ if self ._cached_vectorstore :
165
+ return self ._cached_vectorstore
144
166
try :
145
167
from langchain_astradb import AstraDBVectorStore
146
168
from langchain_astradb .utils .astradb import SetupMode
@@ -199,13 +221,13 @@ def _build_vector_store_no_ingest(self):
199
221
except Exception as e :
200
222
raise ValueError (f"Error initializing AstraDBVectorStore: { str (e )} " ) from e
201
223
224
+ self ._cached_vectorstore = vector_store
225
+
202
226
return vector_store
203
227
204
228
def build_vector_store (self ):
205
229
vector_store = self ._build_vector_store_no_ingest ()
206
- if hasattr (self , "ingest_data" ) and self .ingest_data :
207
- logger .debug ("Ingesting data into the Vector Store." )
208
- self ._add_documents_to_vector_store (vector_store )
230
+ self ._add_documents_to_vector_store (vector_store )
209
231
return vector_store
210
232
211
233
def _add_documents_to_vector_store (self , vector_store ):
@@ -216,7 +238,7 @@ def _add_documents_to_vector_store(self, vector_store):
216
238
else :
217
239
raise ValueError ("Vector Store Inputs must be Data objects." )
218
240
219
- if documents and self . embedding is not None :
241
+ if documents :
220
242
logger .debug (f"Adding { len (documents )} documents to the Vector Store." )
221
243
try :
222
244
vector_store .add_documents (documents )
@@ -225,36 +247,56 @@ def _add_documents_to_vector_store(self, vector_store):
225
247
else :
226
248
logger .debug ("No documents to add to the Vector Store." )
227
249
250
+ def _map_search_type (self ):
251
+ if self .search_type == "Similarity with score threshold" :
252
+ return "similarity_score_threshold"
253
+ elif self .search_type == "MMR (Max Marginal Relevance)" :
254
+ return "mmr"
255
+ else :
256
+ return "similarity"
257
+
228
258
def search_documents (self ) -> list [Data ]:
229
259
vector_store = self ._build_vector_store_no_ingest ()
260
+ self ._add_documents_to_vector_store (vector_store )
230
261
231
262
logger .debug (f"Search input: { self .search_input } " )
232
263
logger .debug (f"Search type: { self .search_type } " )
233
264
logger .debug (f"Number of results: { self .number_of_results } " )
234
265
235
266
if self .search_input and isinstance (self .search_input , str ) and self .search_input .strip ():
236
267
try :
237
- if self .search_type == "Similarity" :
238
- docs = vector_store .similarity_search (
239
- query = self .search_input ,
240
- k = self .number_of_results ,
241
- )
242
- elif self .search_type == "MMR" :
243
- docs = vector_store .max_marginal_relevance_search (
244
- query = self .search_input ,
245
- k = self .number_of_results ,
246
- )
247
- else :
248
- raise ValueError (f"Invalid search type: { self .search_type } " )
268
+ search_type = self ._map_search_type ()
269
+ search_args = self ._build_search_args ()
270
+
271
+ docs = vector_store .search (query = self .search_input , search_type = search_type , ** search_args )
249
272
except Exception as e :
250
273
raise ValueError (f"Error performing search in AstraDBVectorStore: { str (e )} " ) from e
251
274
252
275
logger .debug (f"Retrieved documents: { len (docs )} " )
253
276
254
- data = [ Data . from_document ( doc ) for doc in docs ]
277
+ data = docs_to_data ( docs )
255
278
logger .debug (f"Converted documents to data: { len (data )} " )
256
279
self .status = data
257
280
return data
258
281
else :
259
282
logger .debug ("No search input provided. Skipping search." )
260
283
return []
284
+
285
+ def _build_search_args (self ):
286
+ args = {
287
+ "k" : self .number_of_results ,
288
+ "score_threshold" : self .search_score_threshold ,
289
+ }
290
+
291
+ if self .search_filter :
292
+ clean_filter = {k : v for k , v in self .search_filter .items () if k and v }
293
+ if len (clean_filter ) > 0 :
294
+ args ["filter" ] = clean_filter
295
+ return args
296
+
297
+ def get_retriever_kwargs (self ):
298
+ search_args = self ._build_search_args ()
299
+ return {
300
+ "search_type" : self ._map_search_type (),
301
+ "search_kwargs" : search_args ,
302
+ }
0 commit comments