Skip to content

Fix #12448 - update bedrock retrieve tool, support hybrid search type and re… #12446

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,38 @@ class BedrockRetrieveTool(BuiltinTool):
topk: int = None

def _bedrock_retrieve(
self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
self,
query_input: str,
knowledge_base_id: str,
num_results: int,
search_type: str,
rerank_model_id: str,
metadata_filter: Optional[dict] = None,
):
try:
retrieval_query = {"text": query_input}

retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
if search_type not in ["HYBRID", "SEMANTIC"]:
raise RuntimeException("search_type should be HYBRID or SEMANTIC")

retrieval_configuration = {
"vectorSearchConfiguration": {"numberOfResults": num_results, "overrideSearchType": search_type}
}

if rerank_model_id != "default":
model_for_rerank_arn = f"arn:aws:bedrock:us-west-2::foundation-model/{rerank_model_id}"
rerankingConfiguration = {
"bedrockRerankingConfiguration": {
"numberOfRerankedResults": num_results,
"modelConfiguration": {"modelArn": model_for_rerank_arn},
},
"type": "BEDROCK_RERANKING_MODEL",
}

# Add metadata filter to retrieval configuration if present
retrieval_configuration["vectorSearchConfiguration"]["rerankingConfiguration"] = rerankingConfiguration
retrieval_configuration["vectorSearchConfiguration"]["numberOfResults"] = num_results * 5

# 如果有元数据过滤条件,则添加到检索配置中
if metadata_filter:
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter

Expand Down Expand Up @@ -77,15 +101,20 @@ def _invoke(
if not query:
return self.create_text_message("Please input query")

# Get metadata filter conditions (if they exist)
# 获取元数据过滤条件(如果存在)
metadata_filter_str = tool_parameters.get("metadata_filter")
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None

search_type = tool_parameters.get("search_type")
rerank_model_id = tool_parameters.get("rerank_model_id")

line = 4
retrieved_docs = self._bedrock_retrieve(
query_input=query,
knowledge_base_id=self.knowledge_base_id,
num_results=self.topk,
search_type=search_type,
rerank_model_id=rerank_model_id,
metadata_filter=metadata_filter,
)

Expand All @@ -109,7 +138,7 @@ def validate_parameters(self, parameters: dict[str, Any]) -> None:
if not parameters.get("query"):
raise ValueError("query is required")

# Optional: Validate if metadata filter is a valid JSON string (if provided)
# 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
metadata_filter_str = parameters.get("metadata_filter")
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
raise ValueError("metadata_filter must be a valid JSON object")
51 changes: 51 additions & 0 deletions api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,57 @@ parameters:
max: 10
default: 5

- name: search_type
type: select
required: false
label:
en_US: search type
zh_Hans: 搜索类型
pt_BR: search type
human_description:
en_US: search type
zh_Hans: 搜索类型
pt_BR: search type
llm_description: search type
default: SEMANTIC
options:
- value: SEMANTIC
label:
en_US: SEMANTIC
zh_Hans: 语义搜索
- value: HYBRID
label:
en_US: HYBRID
zh_Hans: 混合搜索
form: form

- name: rerank_model_id
type: select
required: false
label:
en_US: rerank model id
zh_Hans: 重拍模型ID
pt_BR: rerank model id
human_description:
en_US: rerank model id
zh_Hans: 重拍模型ID
pt_BR: rerank model id
llm_description: rerank model id
options:
- value: default
label:
en_US: default
zh_Hans: 默认
- value: cohere.rerank-v3-5:0
label:
en_US: cohere.rerank-v3-5:0
zh_Hans: cohere.rerank-v3-5:0
- value: amazon.rerank-v1:0
label:
en_US: amazon.rerank-v1:0
zh_Hans: amazon.rerank-v1:0
form: form

- name: aws_region
type: string
required: false
Expand Down
Loading