Skip to content

Added OceanBase as an option for the vector store in Dify #10010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ docker/volumes/unstructured/*
docker/volumes/pgvector/data/*
docker/volumes/pgvecto_rs/data/*
docker/volumes/couchbase/*
docker/volumes/oceanbase/*

docker/nginx/conf.d/default.conf
docker/nginx/ssl/*
Expand Down
8 changes: 8 additions & 0 deletions api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,14 @@ VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30

# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
OCEANBASE_VECTOR_USER=root@test
OCEANBASE_VECTOR_PASSWORD=
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G

# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
Expand Down
1 change: 1 addition & 0 deletions api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ def migrate_knowledge_vector_database():
VectorType.VIKINGDB,
VectorType.UPSTASH,
VectorType.COUCHBASE,
VectorType.OCEANBASE,
}
page = 1
while True:
Expand Down
2 changes: 2 additions & 0 deletions api/configs/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
from configs.middleware.vdb.oracle_config import OracleConfig
from configs.middleware.vdb.pgvector_config import PGVectorConfig
Expand Down Expand Up @@ -257,5 +258,6 @@ class MiddlewareConfig(
VikingDBConfig,
UpstashConfig,
TidbOnQdrantConfig,
OceanBaseVectorConfig,
):
pass
35 changes: 35 additions & 0 deletions api/configs/middleware/vdb/oceanbase_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Optional

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class OceanBaseVectorConfig(BaseSettings):
    """
    Connection settings for the OceanBase Vector database.

    Every field is optional at this level; only the port carries a
    non-``None`` default (2881).
    """

    # Server address.
    OCEANBASE_VECTOR_HOST: Optional[str] = Field(
        default=None,
        description="Hostname or IP address of the OceanBase Vector server (e.g. 'localhost')",
    )

    # Listening port; must be a positive integer when provided.
    OCEANBASE_VECTOR_PORT: Optional[PositiveInt] = Field(
        default=2881,
        description="Port number on which the OceanBase Vector server is listening (default is 2881)",
    )

    # Login user.
    OCEANBASE_VECTOR_USER: Optional[str] = Field(
        default=None,
        description="Username for authenticating with the OceanBase Vector database",
    )

    # Login password.
    OCEANBASE_VECTOR_PASSWORD: Optional[str] = Field(
        default=None,
        description="Password for authenticating with the OceanBase Vector database",
    )

    # Target database name.
    OCEANBASE_VECTOR_DATABASE: Optional[str] = Field(
        default=None,
        description="Name of the OceanBase Vector database to connect to",
    )
2 changes: 2 additions & 0 deletions api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ def get(self):
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
| VectorType.OCEANBASE
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
Expand Down Expand Up @@ -669,6 +670,7 @@ def get(self, vector_type):
| VectorType.BAIDU
| VectorType.VIKINGDB
| VectorType.UPSTASH
| VectorType.OCEANBASE
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
Expand Down
Empty file.
209 changes: 209 additions & 0 deletions api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import json
import logging
import math
from typing import Any

from pydantic import BaseModel, model_validator
from pyobvector import VECTOR, ObVecClient
from sqlalchemy import JSON, Column, String, func
from sqlalchemy.dialects.mysql import LONGTEXT

from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# HNSW build parameters, passed as `params` when the vector index is added.
DEFAULT_OCEANBASE_HNSW_BUILD_PARAM = {"M": 16, "efConstruction": 256}
# HNSW search parameters.
# NOTE(review): not referenced anywhere in the visible code - confirm it is still needed.
DEFAULT_OCEANBASE_HNSW_SEARCH_PARAM = {"efSearch": 64}
# Index type used for the index on the `vector` column.
OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE = "HNSW"
# Metric type given to the vector index at creation time.
DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE = "l2"


class OceanBaseVectorConfig(BaseModel):
    """Validated connection parameters for an OceanBase vector store.

    ``host``, ``port``, ``user`` and ``database`` must be present and truthy;
    ``password`` is deliberately not checked by the validator, so an empty
    string is accepted.
    """

    host: str
    port: int
    user: str
    password: str
    database: str

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        """Reject configurations with a missing or empty required setting.

        Runs before field validation, so ``values`` is the raw input mapping.
        ``.get`` is used instead of indexing so that an absent key produces
        the same descriptive error as an empty value rather than a bare
        ``KeyError``.

        Raises:
            ValueError: naming the environment variable that must be set.
        """
        if not values.get("host"):
            raise ValueError("config OCEANBASE_VECTOR_HOST is required")
        if not values.get("port"):
            raise ValueError("config OCEANBASE_VECTOR_PORT is required")
        if not values.get("user"):
            raise ValueError("config OCEANBASE_VECTOR_USER is required")
        if not values.get("database"):
            raise ValueError("config OCEANBASE_VECTOR_DATABASE is required")
        return values


class OceanBaseVector(BaseVector):
    """Vector store implementation that keeps a collection in an OceanBase table.

    The table has the columns ``id``, ``vector``, ``text`` and ``metadata``
    and an HNSW index named ``vector_index`` on ``vector``. All database
    access goes through an ``ObVecClient``.
    """

    def __init__(self, collection_name: str, config: OceanBaseVectorConfig):
        """Open a client connection for ``collection_name`` using ``config``."""
        super().__init__(collection_name)
        self._config = config
        # Last ef_search value sent to the server by search_by_vector;
        # -1 until one has been sent.
        self._hnsw_ef_search = -1
        self._client = ObVecClient(
            uri=f"{self._config.host}:{self._config.port}",
            user=self._config.user,
            password=self._config.password,
            db_name=self._config.database,
        )

    def get_type(self) -> str:
        """Return the vector store type identifier for OceanBase."""
        return VectorType.OCEANBASE

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """Create the collection table if needed, then insert the documents.

        The vector dimension is taken from the first embedding, so
        ``embeddings`` must not be empty.
        """
        self._vec_dim = len(embeddings[0])
        self._create_collection()
        self.add_texts(texts, embeddings)

    def _create_collection(self) -> None:
        """Create the table and its vector index unless it already exists.

        The whole operation runs under a Redis lock. A Redis key caches the
        fact that the collection was created for one hour.

        Raises:
            ValueError: if ``ob_vector_memory_limit_percentage`` is not
                reported by the server.
            Exception: if that parameter is 0 and cannot be changed.
        """
        lock_name = "vector_indexing_lock_" + self._collection_name
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = "vector_indexing_" + self._collection_name
            if redis_client.get(collection_exist_cache_key):
                return

            if self._client.check_table_exists(self._collection_name):
                return

            self.delete()

            cols = [
                Column("id", String(36), primary_key=True, autoincrement=False),
                Column("vector", VECTOR(self._vec_dim)),
                Column("text", LONGTEXT),
                Column("metadata", JSON),
            ]
            vidx_params = self._client.prepare_index_params()
            vidx_params.add_index(
                field_name="vector",
                index_type=OCEANBASE_SUPPORTED_VECTOR_INDEX_TYPE,
                index_name="vector_index",
                metric_type=DEFAULT_OCEANBASE_VECTOR_METRIC_TYPE,
                params=DEFAULT_OCEANBASE_HNSW_BUILD_PARAM,
            )

            self._client.create_table_with_index_params(
                table_name=self._collection_name,
                columns=cols,
                vidxs=vidx_params,
            )
            # NOTE(review): the parameter check runs after the table is created,
            # as in the original; confirm whether it should run before.
            self._ensure_vector_memory_limit()
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def _ensure_vector_memory_limit(self) -> None:
        """Make sure ``ob_vector_memory_limit_percentage`` is non-zero on the server."""
        rows = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'")
        # NOTE(review): index 6 is presumably the value column of the
        # SHOW PARAMETERS result - confirm against the server's output.
        limits = [int(row[6]) for row in rows]
        if not limits:
            # Raise instead of print + exit(1): terminating the interpreter
            # from library code would take down the whole worker process.
            raise ValueError("ob_vector_memory_limit_percentage not found in parameters.")
        if any(limit == 0 for limit in limits):
            try:
                self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30")
            except Exception as e:
                raise Exception(
                    "Failed to set ob_vector_memory_limit_percentage. "
                    + "Maybe the database user has insufficient privilege.",
                    e,
                ) from e

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Insert each document with its embedding, one row per document."""
        ids = self._get_uuids(documents)
        for doc_id, doc, emb in zip(ids, documents, embeddings):
            self._client.insert(
                table_name=self._collection_name,
                data={
                    "id": doc_id,
                    "vector": emb,
                    "text": doc.page_content,
                    "metadata": doc.metadata,
                },
            )

    def text_exists(self, id: str) -> bool:
        """Return True if a row with primary key ``id`` exists."""
        cur = self._client.get(table_name=self._collection_name, id=id)
        return cur.rowcount != 0

    def delete_by_ids(self, ids: list[str]) -> None:
        """Delete the rows whose primary keys are in ``ids``; no-op when empty."""
        if not ids:
            # Nothing to delete; skip the database round-trip.
            return
        self._client.delete(table_name=self._collection_name, ids=ids)

    def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
        """Return the ids of rows whose JSON ``metadata`` has ``key`` equal to ``value``."""
        # NOTE(review): key and value are interpolated into the SQL text
        # unescaped. Callers must pass trusted values only; consider a
        # parameterized clause if the client supports one.
        cur = self._client.get(
            table_name=self._collection_name,
            where_clause=f"metadata->>'$.{key}' = '{value}'",
            output_column_name=["id"],
        )
        return [row[0] for row in cur]

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        """Delete every row whose ``metadata`` field ``key`` equals ``value``."""
        ids = self.get_ids_by_metadata_field(key, value)
        self.delete_by_ids(ids)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Full-text search is not implemented; always returns an empty list."""
        return []

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """Run an approximate nearest-neighbour search using L2 distance.

        Keyword Args:
            ef_search: HNSW ef_search value; sent to the server only when it
                differs from the last value sent.
            top_k: number of rows to return (default 10).

        Returns:
            Documents with a ``score`` entry added to their metadata.
        """
        ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
        if ef_search != self._hnsw_ef_search:
            self._client.set_ob_hnsw_ef_search(ef_search)
            self._hnsw_ef_search = ef_search
        topk = kwargs.get("top_k", 10)
        cur = self._client.ann_search(
            table_name=self._collection_name,
            vec_column_name="vector",
            vec_data=query_vector,
            topk=topk,
            distance_func=func.l2_distance,
            output_column_names=["text", "metadata"],
            with_dist=True,
        )
        docs = []
        for text, metadata, distance in cur:
            metadata = json.loads(metadata)
            # NOTE(review): this distance-to-score mapping presumably assumes
            # normalized embeddings - confirm.
            metadata["score"] = 1 - distance / math.sqrt(2)
            docs.append(
                Document(
                    page_content=text,
                    metadata=metadata,
                )
            )
        return docs

    def delete(self) -> None:
        """Drop the collection table if it exists."""
        self._client.drop_table_if_exist(self._collection_name)


class OceanBaseVectorFactory(AbstractVectorFactory):
    """Factory that builds an ``OceanBaseVector`` for a dataset."""

    def init_vector(
        self,
        dataset: Dataset,
        attributes: list,
        embeddings: Embeddings,
    ) -> BaseVector:
        """Resolve the dataset's collection name and return a connected store.

        A dataset without an index struct gets a collection name generated
        from its id, and the new index struct is stored on the dataset as
        JSON. Otherwise the stored ``class_prefix`` is reused. The name is
        lower-cased in both cases.
        """
        if not dataset.index_struct_dict:
            collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()
            index_struct = self.gen_index_struct_dict(VectorType.OCEANBASE, collection_name)
            dataset.index_struct = json.dumps(index_struct)
        else:
            stored_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = stored_prefix.lower()

        connection = OceanBaseVectorConfig(
            host=dify_config.OCEANBASE_VECTOR_HOST,
            port=dify_config.OCEANBASE_VECTOR_PORT,
            user=dify_config.OCEANBASE_VECTOR_USER,
            # An unset password is passed through as an empty string.
            password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""),
            database=dify_config.OCEANBASE_VECTOR_DATABASE,
        )
        return OceanBaseVector(collection_name, connection)
4 changes: 4 additions & 0 deletions api/core/rag/datasource/vdb/vector_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory

return TidbOnQdrantVectorFactory
case VectorType.OCEANBASE:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory

return OceanBaseVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

Expand Down
1 change: 1 addition & 0 deletions api/core/rag/datasource/vdb/vector_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ class VectorType(str, Enum):
VIKINGDB = "vikingdb"
UPSTASH = "upstash"
TIDB_ON_QDRANT = "tidb_on_qdrant"
OCEANBASE = "oceanbase"
Loading