Skip to content

Commit fa2ccfc

Browse files
committed
Cleanup + Lint
1 parent 0aef582 commit fa2ccfc

File tree

5 files changed

+82
-119
lines changed

5 files changed

+82
-119
lines changed

api/configs/middleware/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,6 @@ class MiddlewareConfig(
222222
TiDBVectorConfig,
223223
WeaviateConfig,
224224
ElasticsearchConfig,
225-
CouchbaseConfig
225+
CouchbaseConfig,
226226
):
227227
pass

api/configs/middleware/vdb/couchbase_config.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,26 @@ class CouchbaseConfig(BaseModel):
99
"""
1010

1111
COUCHBASE_CONNECTION_STRING: Optional[str] = Field(
12-
description='COUCHBASE connection string',
12+
description="COUCHBASE connection string",
1313
default=None,
1414
)
1515

1616
COUCHBASE_USER: Optional[str] = Field(
17-
description='COUCHBASE user',
17+
description="COUCHBASE user",
1818
default=None,
1919
)
2020

2121
COUCHBASE_PASSWORD: Optional[str] = Field(
22-
description='COUCHBASE password',
22+
description="COUCHBASE password",
2323
default=None,
2424
)
2525

2626
COUCHBASE_BUCKET_NAME: Optional[str] = Field(
27-
description='COUCHBASE bucket name',
27+
description="COUCHBASE bucket name",
2828
default=None,
29-
3029
)
3130

3231
COUCHBASE_SCOPE_NAME: Optional[str] = Field(
33-
description='COUCHBASE scope name',
32+
description="COUCHBASE scope name",
3433
default=None,
35-
3634
)

api/core/rag/datasource/vdb/couchbase/couchbase_vector.py

+70-88
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,23 @@ class CouchbaseConfig(BaseModel):
3434
bucket_name: str
3535
scope_name: str
3636

37-
@model_validator(mode='before')
37+
@model_validator(mode="before")
38+
@classmethod
3839
def validate_config(cls, values: dict) -> dict:
39-
if not values.get('connection_string'):
40+
if not values.get("connection_string"):
4041
raise ValueError("config COUCHBASE_CONNECTION_STRING is required")
41-
if not values.get('user'):
42+
if not values.get("user"):
4243
raise ValueError("config COUCHBASE_USER is required")
43-
if not values.get('password'):
44+
if not values.get("password"):
4445
raise ValueError("config COUCHBASE_PASSWORD is required")
45-
if not values.get('bucket_name'):
46+
if not values.get("bucket_name"):
4647
raise ValueError("config COUCHBASE_PASSWORD is required")
47-
if not values.get('scope_name'):
48+
if not values.get("scope_name"):
4849
raise ValueError("config COUCHBASE_SCOPE_NAME is required")
4950
return values
50-
51-
class CouchbaseVector(BaseVector):
5251

52+
53+
class CouchbaseVector(BaseVector):
5354
def __init__(self, collection_name: str, config: CouchbaseConfig):
5455
super().__init__(collection_name)
5556
self._client_config = config
@@ -68,14 +69,14 @@ def __init__(self, collection_name: str, config: CouchbaseConfig):
6869
self._cluster.wait_until_ready(timedelta(seconds=5))
6970

7071
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
71-
index_id = str(uuid.uuid4()).replace('-','')
72-
self._create_collection(uuid=index_id,vector_length=len(embeddings[0]))
72+
index_id = str(uuid.uuid4()).replace("-", "")
73+
self._create_collection(uuid=index_id, vector_length=len(embeddings[0]))
7374
self.add_texts(texts, embeddings)
7475

7576
def _create_collection(self, vector_length: int, uuid: str):
76-
lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
77+
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
7778
with redis_client.lock(lock_name, timeout=20):
78-
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
79+
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
7980
if redis_client.get(collection_exist_cache_key):
8081
return
8182
if self._collection_exists(self._collection_name):
@@ -165,10 +166,14 @@ def _create_collection(self, vector_length: int, uuid: str):
165166
"sourceParams": { }
166167
}
167168
""")
168-
index_definition['name'] = self._collection_name + '_search'
169-
index_definition['uuid'] = uuid
170-
index_definition['params']['mapping']['types']['collection_name']['properties']['embedding']['fields'][0]['dims'] = vector_length
171-
index_definition['params']['mapping']['types'][self._scope_name + '.' + self._collection_name] = index_definition['params']['mapping']['types'].pop('collection_name')
169+
index_definition["name"] = self._collection_name + "_search"
170+
index_definition["uuid"] = uuid
171+
index_definition["params"]["mapping"]["types"]["collection_name"]["properties"]["embedding"]["fields"][0][
172+
"dims"
173+
] = vector_length
174+
index_definition["params"]["mapping"]["types"][self._scope_name + "." + self._collection_name] = (
175+
index_definition["params"]["mapping"]["types"].pop("collection_name")
176+
)
172177
time.sleep(2)
173178
index_manager.upsert_index(
174179
SearchIndex(
@@ -206,32 +211,27 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
206211
doc_ids = []
207212

208213
documents_to_insert = [
209-
{
210-
'text': text,
211-
'embedding': vector,
212-
'metadata': metadata
213-
}
214-
for id, text, vector, metadata in zip(
215-
uuids, texts, embeddings, metadatas
216-
)
214+
{"text": text, "embedding": vector, "metadata": metadata}
215+
for id, text, vector, metadata in zip(uuids, texts, embeddings, metadatas)
217216
]
218-
for doc,id in zip(documents_to_insert,uuids):
219-
result = self._scope.collection(self._collection_name).upsert(id,doc)
217+
for doc, id in zip(documents_to_insert, uuids):
218+
result = self._scope.collection(self._collection_name).upsert(id, doc)
220219

221-
222-
223-
224220
doc_ids.extend(uuids)
225-
221+
226222
return doc_ids
227223

228224
def text_exists(self, id: str) -> bool:
229225
# Use a parameterized query for safety and correctness
230-
query = f"SELECT COUNT(1) AS count FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name} WHERE META().id = $doc_id"
226+
query = f"""
227+
SELECT COUNT(1) AS count FROM
228+
`{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
229+
WHERE META().id = $doc_id
230+
"""
231231
# Pass the id as a parameter to the query
232-
result = self._cluster.query(query, named_parameters={"doc_id": id})
232+
result = self._cluster.query(query, named_parameters={"doc_id": id}).execute()
233233
for row in result:
234-
return row['count'] > 0
234+
return row["count"] > 0
235235
return False # Return False if no rows are returned
236236

237237
def delete_by_ids(self, ids: list[str]) -> None:
@@ -240,72 +240,61 @@ def delete_by_ids(self, ids: list[str]) -> None:
240240
WHERE META().id IN $doc_ids;
241241
"""
242242
try:
243-
result = self._cluster.query(query, named_parameters={'doc_ids': ids})
244-
# force evaluation of the query to ensure deletion occurs
245-
list(result)
243+
self._cluster.query(query, named_parameters={"doc_ids": ids}).execute()
246244
except Exception as e:
247245
logger.error(e)
248246

249247
def delete_by_document_id(self, document_id: str):
250248
query = f"""
251-
DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
249+
DELETE FROM
250+
`{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
252251
WHERE META().id = $doc_id;
253252
"""
254-
result = self._cluster.query(query,named_parameters={'doc_id':document_id})
255-
# force evaluation of the query to ensure deletion occurs
256-
list(result)
253+
self._cluster.query(query, named_parameters={"doc_id": document_id}).execute()
257254

258255
# def get_ids_by_metadata_field(self, key: str, value: str):
259256
# query = f"""
260-
# SELECT id FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
257+
# SELECT id FROM
258+
# `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
261259
# WHERE `metadata.{key}` = $value;
262260
# """
263261
# result = self._cluster.query(query, named_parameters={'value':value})
264262
# return [row['id'] for row in result.rows()]
265263

266-
267264
def delete_by_metadata_field(self, key: str, value: str) -> None:
268265
query = f"""
269266
DELETE FROM `{self._client_config.bucket_name}`.{self._client_config.scope_name}.{self._collection_name}
270267
WHERE metadata.{key} = $value;
271268
"""
272-
result = self._cluster.query(query, named_parameters={'value':value})
273-
# force evaluation of the query to ensure deletion occurs
274-
list(result)
275-
276-
def search_by_vector(
277-
self,
278-
query_vector: list[float],
279-
**kwargs: Any
280-
) -> list[Document]:
269+
self._cluster.query(query, named_parameters={"value": value}).execute()
281270

271+
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
282272
top_k = kwargs.get("top_k", 5)
283-
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
273+
score_threshold = kwargs.get("score_threshold") or 0.0
284274

285275
search_req = search.SearchRequest.create(
286276
VectorSearch.from_vector_query(
287277
VectorQuery(
288-
'embedding',
278+
"embedding",
289279
query_vector,
290280
top_k,
291-
292281
)
293282
)
294-
)
283+
)
295284
try:
296285
search_iter = self._scope.search(
297-
self._collection_name + '_search',
298-
search_req,
299-
SearchOptions(limit=top_k, collections=[self._collection_name],fields=['*']),
300-
)
286+
self._collection_name + "_search",
287+
search_req,
288+
SearchOptions(limit=top_k, collections=[self._collection_name], fields=["*"]),
289+
)
301290

302291
docs = []
303292
# Parse the results
304293
for row in search_iter.rows():
305-
text = row.fields.pop('text')
294+
text = row.fields.pop("text")
306295
metadata = self._format_metadata(row.fields)
307296
score = row.score
308-
metadata['score'] = score
297+
metadata["score"] = score
309298
doc = Document(page_content=text, metadata=metadata)
310299
if score >= score_threshold:
311300
docs.append(doc)
@@ -314,41 +303,36 @@ def search_by_vector(
314303

315304
return docs
316305

317-
def search_by_full_text(
318-
self, query: str,
319-
**kwargs: Any
320-
) -> list[Document]:
321-
top_k=kwargs.get('top_k', 2)
306+
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
307+
top_k = kwargs.get("top_k", 2)
322308
try:
323-
CBrequest = search.SearchRequest.create(search.QueryStringQuery('text:'+query))
324-
search_iter = self._scope.search(self._collection_name + '_search',
325-
CBrequest,
326-
SearchOptions(limit=top_k,fields=['*']))
309+
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
310+
search_iter = self._scope.search(
311+
self._collection_name + "_search", CBrequest, SearchOptions(limit=top_k, fields=["*"])
312+
)
327313

328-
329314
docs = []
330315
for row in search_iter.rows():
331-
text = row.fields.pop('text')
316+
text = row.fields.pop("text")
332317
metadata = self._format_metadata(row.fields)
333318
score = row.score
334-
metadata['score'] = score
319+
metadata["score"] = score
335320
doc = Document(page_content=text, metadata=metadata)
336321
docs.append(doc)
337322

338323
except Exception as e:
339324
raise ValueError(f"Search failed with error: {e}")
340-
325+
341326
return docs
342-
327+
343328
def delete(self):
344329
manager = self._bucket.collections()
345330
scopes = manager.get_all_scopes()
346331

347-
348332
for scope in scopes:
349333
for collection in scope.collections:
350334
if collection.name == self._collection_name:
351-
manager.drop_collection('_default', self._collection_name)
335+
manager.drop_collection("_default", self._collection_name)
352336

353337
def _format_metadata(self, row_fields: dict[str, Any]) -> dict[str, Any]:
354338
"""Helper method to format the metadata from the Couchbase Search API.
@@ -362,16 +346,15 @@ def _format_metadata(self, row_fields: dict[str, Any]) -> dict[str, Any]:
362346
for key, value in row_fields.items():
363347
# Couchbase Search returns the metadata key with a prefix
364348
# `metadata.` We remove it to get the original metadata key
365-
if key.startswith('metadata'):
366-
new_key = key.split('metadata' + ".")[-1]
349+
if key.startswith("metadata"):
350+
new_key = key.split("metadata" + ".")[-1]
367351
metadata[new_key] = value
368352
else:
369353
metadata[key] = value
370354

371355
return metadata
372356

373357

374-
375358
class CouchbaseVectorFactory(AbstractVectorFactory):
376359
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> CouchbaseVector:
377360
if dataset.index_struct_dict:
@@ -380,17 +363,16 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
380363
else:
381364
dataset_id = dataset.id
382365
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
383-
dataset.index_struct = json.dumps(
384-
self.gen_index_struct_dict(VectorType.COUCHBASE, collection_name))
366+
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.COUCHBASE, collection_name))
385367

386368
config = current_app.config
387369
return CouchbaseVector(
388370
collection_name=collection_name,
389371
config=CouchbaseConfig(
390-
connection_string=config.get('COUCHBASE_CONNECTION_STRING'),
391-
user=config.get('COUCHBASE_USER'),
392-
password=config.get('COUCHBASE_PASSWORD'),
393-
bucket_name=config.get('COUCHBASE_BUCKET_NAME'),
394-
scope_name=config.get('COUCHBASE_SCOPE_NAME'),
395-
)
372+
connection_string=config.get("COUCHBASE_CONNECTION_STRING"),
373+
user=config.get("COUCHBASE_USER"),
374+
password=config.get("COUCHBASE_PASSWORD"),
375+
bucket_name=config.get("COUCHBASE_BUCKET_NAME"),
376+
scope_name=config.get("COUCHBASE_SCOPE_NAME"),
377+
),
396378
)

api/core/rag/datasource/vdb/vector_type.py

-18
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33

44
class VectorType(str, Enum):
5-
<<<<<<< HEAD
65
ANALYTICDB = "analyticdb"
76
CHROMA = "chroma"
87
MILVUS = "milvus"
@@ -18,20 +17,3 @@ class VectorType(str, Enum):
1817
ORACLE = "oracle"
1918
ELASTICSEARCH = "elasticsearch"
2019
COUCHBASE = "couchbase"
21-
=======
22-
ANALYTICDB = 'analyticdb'
23-
CHROMA = 'chroma'
24-
MILVUS = 'milvus'
25-
MYSCALE = 'myscale'
26-
PGVECTOR = 'pgvector'
27-
PGVECTO_RS = 'pgvecto-rs'
28-
QDRANT = 'qdrant'
29-
RELYT = 'relyt'
30-
TIDB_VECTOR = 'tidb_vector'
31-
WEAVIATE = 'weaviate'
32-
OPENSEARCH = 'opensearch'
33-
TENCENT = 'tencent'
34-
ORACLE = 'oracle'
35-
ELASTICSEARCH = 'elasticsearch'
36-
COUCHBASE = 'couchbase'
37-
>>>>>>> 8d7e8c48 (Cleanup)

api/tests/integration_tests/vdb/couchbase/test_couchbase.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@ def __init__(self):
1616
self.vector = CouchbaseVector(
1717
collection_name=self.collection_name,
1818
config=CouchbaseConfig(
19-
connection_string = '127.0.0.1',
20-
user = 'Administrator',
21-
password = 'password',
22-
bucket_name = 'Embeddings',
23-
scope_name = '_default',
19+
connection_string="127.0.0.1",
20+
user="Administrator",
21+
password="password",
22+
bucket_name="Embeddings",
23+
scope_name="_default",
2424
),
2525
)
2626

27+
2728
def test_couchbase(setup_mock_redis):
2829
CouchbaseTest().run_all_tests()

0 commit comments

Comments
 (0)