@@ -34,22 +34,23 @@ class CouchbaseConfig(BaseModel):
34
34
bucket_name : str
35
35
scope_name : str
36
36
37
- @model_validator (mode = 'before' )
37
+ @model_validator (mode = "before" )
38
+ @classmethod
38
39
def validate_config (cls , values : dict ) -> dict :
39
- if not values .get (' connection_string' ):
40
+ if not values .get (" connection_string" ):
40
41
raise ValueError ("config COUCHBASE_CONNECTION_STRING is required" )
41
- if not values .get (' user' ):
42
+ if not values .get (" user" ):
42
43
raise ValueError ("config COUCHBASE_USER is required" )
43
- if not values .get (' password' ):
44
+ if not values .get (" password" ):
44
45
raise ValueError ("config COUCHBASE_PASSWORD is required" )
45
- if not values .get (' bucket_name' ):
46
+ if not values .get (" bucket_name" ):
46
47
raise ValueError ("config COUCHBASE_PASSWORD is required" )
47
- if not values .get (' scope_name' ):
48
+ if not values .get (" scope_name" ):
48
49
raise ValueError ("config COUCHBASE_SCOPE_NAME is required" )
49
50
return values
50
-
51
- class CouchbaseVector (BaseVector ):
52
51
52
+
53
+ class CouchbaseVector (BaseVector ):
53
54
def __init__ (self , collection_name : str , config : CouchbaseConfig ):
54
55
super ().__init__ (collection_name )
55
56
self ._client_config = config
@@ -68,14 +69,14 @@ def __init__(self, collection_name: str, config: CouchbaseConfig):
68
69
self ._cluster .wait_until_ready (timedelta (seconds = 5 ))
69
70
70
71
def create (self , texts : list [Document ], embeddings : list [list [float ]], ** kwargs ):
71
- index_id = str (uuid .uuid4 ()).replace ('-' , '' )
72
- self ._create_collection (uuid = index_id ,vector_length = len (embeddings [0 ]))
72
+ index_id = str (uuid .uuid4 ()).replace ("-" , "" )
73
+ self ._create_collection (uuid = index_id , vector_length = len (embeddings [0 ]))
73
74
self .add_texts (texts , embeddings )
74
75
75
76
def _create_collection (self , vector_length : int , uuid : str ):
76
- lock_name = ' vector_indexing_lock_{}' .format (self ._collection_name )
77
+ lock_name = " vector_indexing_lock_{}" .format (self ._collection_name )
77
78
with redis_client .lock (lock_name , timeout = 20 ):
78
- collection_exist_cache_key = ' vector_indexing_{}' .format (self ._collection_name )
79
+ collection_exist_cache_key = " vector_indexing_{}" .format (self ._collection_name )
79
80
if redis_client .get (collection_exist_cache_key ):
80
81
return
81
82
if self ._collection_exists (self ._collection_name ):
@@ -165,10 +166,14 @@ def _create_collection(self, vector_length: int, uuid: str):
165
166
"sourceParams": { }
166
167
}
167
168
""" )
168
- index_definition ['name' ] = self ._collection_name + '_search'
169
- index_definition ['uuid' ] = uuid
170
- index_definition ['params' ]['mapping' ]['types' ]['collection_name' ]['properties' ]['embedding' ]['fields' ][0 ]['dims' ] = vector_length
171
- index_definition ['params' ]['mapping' ]['types' ][self ._scope_name + '.' + self ._collection_name ] = index_definition ['params' ]['mapping' ]['types' ].pop ('collection_name' )
169
+ index_definition ["name" ] = self ._collection_name + "_search"
170
+ index_definition ["uuid" ] = uuid
171
+ index_definition ["params" ]["mapping" ]["types" ]["collection_name" ]["properties" ]["embedding" ]["fields" ][0 ][
172
+ "dims"
173
+ ] = vector_length
174
+ index_definition ["params" ]["mapping" ]["types" ][self ._scope_name + "." + self ._collection_name ] = (
175
+ index_definition ["params" ]["mapping" ]["types" ].pop ("collection_name" )
176
+ )
172
177
time .sleep (2 )
173
178
index_manager .upsert_index (
174
179
SearchIndex (
@@ -206,32 +211,27 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
206
211
doc_ids = []
207
212
208
213
documents_to_insert = [
209
- {
210
- 'text' : text ,
211
- 'embedding' : vector ,
212
- 'metadata' : metadata
213
- }
214
- for id , text , vector , metadata in zip (
215
- uuids , texts , embeddings , metadatas
216
- )
214
+ {"text" : text , "embedding" : vector , "metadata" : metadata }
215
+ for id , text , vector , metadata in zip (uuids , texts , embeddings , metadatas )
217
216
]
218
- for doc ,id in zip (documents_to_insert ,uuids ):
219
- result = self ._scope .collection (self ._collection_name ).upsert (id ,doc )
217
+ for doc , id in zip (documents_to_insert , uuids ):
218
+ result = self ._scope .collection (self ._collection_name ).upsert (id , doc )
220
219
221
-
222
-
223
-
224
220
doc_ids .extend (uuids )
225
-
221
+
226
222
return doc_ids
227
223
228
224
def text_exists (self , id : str ) -> bool :
229
225
# Use a parameterized query for safety and correctness
230
- query = f"SELECT COUNT(1) AS count FROM `{ self ._client_config .bucket_name } `.{ self ._client_config .scope_name } .{ self ._collection_name } WHERE META().id = $doc_id"
226
+ query = f"""
227
+ SELECT COUNT(1) AS count FROM
228
+ `{ self ._client_config .bucket_name } `.{ self ._client_config .scope_name } .{ self ._collection_name }
229
+ WHERE META().id = $doc_id
230
+ """
231
231
# Pass the id as a parameter to the query
232
- result = self ._cluster .query (query , named_parameters = {"doc_id" : id })
232
+ result = self ._cluster .query (query , named_parameters = {"doc_id" : id }). execute ()
233
233
for row in result :
234
- return row [' count' ] > 0
234
+ return row [" count" ] > 0
235
235
return False # Return False if no rows are returned
236
236
237
237
def delete_by_ids (self , ids : list [str ]) -> None :
@@ -240,72 +240,61 @@ def delete_by_ids(self, ids: list[str]) -> None:
240
240
WHERE META().id IN $doc_ids;
241
241
"""
242
242
try :
243
- result = self ._cluster .query (query , named_parameters = {'doc_ids' : ids })
244
- # force evaluation of the query to ensure deletion occurs
245
- list (result )
243
+ self ._cluster .query (query , named_parameters = {"doc_ids" : ids }).execute ()
246
244
except Exception as e :
247
245
logger .error (e )
248
246
249
247
def delete_by_document_id (self , document_id : str ):
250
248
query = f"""
251
- DELETE FROM `{ self ._client_config .bucket_name } `.{ self ._client_config .scope_name } .{ self ._collection_name }
249
+ DELETE FROM
250
+ `{ self ._client_config .bucket_name } `.{ self ._client_config .scope_name } .{ self ._collection_name }
252
251
WHERE META().id = $doc_id;
253
252
"""
254
- result = self ._cluster .query (query ,named_parameters = {'doc_id' :document_id })
255
- # force evaluation of the query to ensure deletion occurs
256
- list (result )
253
+ self ._cluster .query (query , named_parameters = {"doc_id" : document_id }).execute ()
257
254
258
255
# def get_ids_by_metadata_field(self, key: str, value: str):
259
256
# query = f"""
260
- # SELECT id FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
257
+ # SELECT id FROM
258
+ # `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
261
259
# WHERE `metadata.{key}` = $value;
262
260
# """
263
261
# result = self._cluster.query(query, named_parameters={'value':value})
264
262
# return [row['id'] for row in result.rows()]
265
263
266
-
267
264
def delete_by_metadata_field (self , key : str , value : str ) -> None :
268
265
query = f"""
269
266
DELETE FROM `{ self ._client_config .bucket_name } `.{ self ._client_config .scope_name } .{ self ._collection_name }
270
267
WHERE metadata.{ key } = $value;
271
268
"""
272
- result = self ._cluster .query (query , named_parameters = {'value' :value })
273
- # force evaluation of the query to ensure deletion occurs
274
- list (result )
275
-
276
- def search_by_vector (
277
- self ,
278
- query_vector : list [float ],
279
- ** kwargs : Any
280
- ) -> list [Document ]:
269
+ self ._cluster .query (query , named_parameters = {"value" : value }).execute ()
281
270
271
+ def search_by_vector (self , query_vector : list [float ], ** kwargs : Any ) -> list [Document ]:
282
272
top_k = kwargs .get ("top_k" , 5 )
283
- score_threshold = kwargs .get ("score_threshold" ) if kwargs . get ( "score_threshold" ) else 0.0
273
+ score_threshold = kwargs .get ("score_threshold" ) or 0.0
284
274
285
275
search_req = search .SearchRequest .create (
286
276
VectorSearch .from_vector_query (
287
277
VectorQuery (
288
- ' embedding' ,
278
+ " embedding" ,
289
279
query_vector ,
290
280
top_k ,
291
-
292
281
)
293
282
)
294
- )
283
+ )
295
284
try :
296
285
search_iter = self ._scope .search (
297
- self ._collection_name + ' _search' ,
298
- search_req ,
299
- SearchOptions (limit = top_k , collections = [self ._collection_name ],fields = ['*' ]),
300
- )
286
+ self ._collection_name + " _search" ,
287
+ search_req ,
288
+ SearchOptions (limit = top_k , collections = [self ._collection_name ], fields = ["*" ]),
289
+ )
301
290
302
291
docs = []
303
292
# Parse the results
304
293
for row in search_iter .rows ():
305
- text = row .fields .pop (' text' )
294
+ text = row .fields .pop (" text" )
306
295
metadata = self ._format_metadata (row .fields )
307
296
score = row .score
308
- metadata [' score' ] = score
297
+ metadata [" score" ] = score
309
298
doc = Document (page_content = text , metadata = metadata )
310
299
if score >= score_threshold :
311
300
docs .append (doc )
@@ -314,41 +303,36 @@ def search_by_vector(
314
303
315
304
return docs
316
305
317
- def search_by_full_text (
318
- self , query : str ,
319
- ** kwargs : Any
320
- ) -> list [Document ]:
321
- top_k = kwargs .get ('top_k' , 2 )
306
+ def search_by_full_text (self , query : str , ** kwargs : Any ) -> list [Document ]:
307
+ top_k = kwargs .get ("top_k" , 2 )
322
308
try :
323
- CBrequest = search .SearchRequest .create (search .QueryStringQuery (' text:' + query ))
324
- search_iter = self ._scope .search (self . _collection_name + '_search' ,
325
- CBrequest ,
326
- SearchOptions ( limit = top_k , fields = [ '*' ]) )
309
+ CBrequest = search .SearchRequest .create (search .QueryStringQuery (" text:" + query ))
310
+ search_iter = self ._scope .search (
311
+ self . _collection_name + "_search" , CBrequest , SearchOptions ( limit = top_k , fields = [ "*" ])
312
+ )
327
313
328
-
329
314
docs = []
330
315
for row in search_iter .rows ():
331
- text = row .fields .pop (' text' )
316
+ text = row .fields .pop (" text" )
332
317
metadata = self ._format_metadata (row .fields )
333
318
score = row .score
334
- metadata [' score' ] = score
319
+ metadata [" score" ] = score
335
320
doc = Document (page_content = text , metadata = metadata )
336
321
docs .append (doc )
337
322
338
323
except Exception as e :
339
324
raise ValueError (f"Search failed with error: { e } " )
340
-
325
+
341
326
return docs
342
-
327
+
343
328
def delete (self ):
344
329
manager = self ._bucket .collections ()
345
330
scopes = manager .get_all_scopes ()
346
331
347
-
348
332
for scope in scopes :
349
333
for collection in scope .collections :
350
334
if collection .name == self ._collection_name :
351
- manager .drop_collection (' _default' , self ._collection_name )
335
+ manager .drop_collection (" _default" , self ._collection_name )
352
336
353
337
def _format_metadata (self , row_fields : dict [str , Any ]) -> dict [str , Any ]:
354
338
"""Helper method to format the metadata from the Couchbase Search API.
@@ -362,16 +346,15 @@ def _format_metadata(self, row_fields: dict[str, Any]) -> dict[str, Any]:
362
346
for key , value in row_fields .items ():
363
347
# Couchbase Search returns the metadata key with a prefix
364
348
# `metadata.` We remove it to get the original metadata key
365
- if key .startswith (' metadata' ):
366
- new_key = key .split (' metadata' + "." )[- 1 ]
349
+ if key .startswith (" metadata" ):
350
+ new_key = key .split (" metadata" + "." )[- 1 ]
367
351
metadata [new_key ] = value
368
352
else :
369
353
metadata [key ] = value
370
354
371
355
return metadata
372
356
373
357
374
-
375
358
class CouchbaseVectorFactory (AbstractVectorFactory ):
376
359
def init_vector (self , dataset : Dataset , attributes : list , embeddings : Embeddings ) -> CouchbaseVector :
377
360
if dataset .index_struct_dict :
@@ -380,17 +363,16 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
380
363
else :
381
364
dataset_id = dataset .id
382
365
collection_name = Dataset .gen_collection_name_by_id (dataset_id )
383
- dataset .index_struct = json .dumps (
384
- self .gen_index_struct_dict (VectorType .COUCHBASE , collection_name ))
366
+ dataset .index_struct = json .dumps (self .gen_index_struct_dict (VectorType .COUCHBASE , collection_name ))
385
367
386
368
config = current_app .config
387
369
return CouchbaseVector (
388
370
collection_name = collection_name ,
389
371
config = CouchbaseConfig (
390
- connection_string = config .get (' COUCHBASE_CONNECTION_STRING' ),
391
- user = config .get (' COUCHBASE_USER' ),
392
- password = config .get (' COUCHBASE_PASSWORD' ),
393
- bucket_name = config .get (' COUCHBASE_BUCKET_NAME' ),
394
- scope_name = config .get (' COUCHBASE_SCOPE_NAME' ),
395
- )
372
+ connection_string = config .get (" COUCHBASE_CONNECTION_STRING" ),
373
+ user = config .get (" COUCHBASE_USER" ),
374
+ password = config .get (" COUCHBASE_PASSWORD" ),
375
+ bucket_name = config .get (" COUCHBASE_BUCKET_NAME" ),
376
+ scope_name = config .get (" COUCHBASE_SCOPE_NAME" ),
377
+ ),
396
378
)
0 commit comments