-
Notifications
You must be signed in to change notification settings - Fork 1
Tool retrieval 2 #417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+291
−77
Merged
Tool retrieval 2 #417
Changes from 11 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
8168898
Temp
WonderPG 0b5ada3
Edit rule
WonderPG 825b9a0
Merge branch 'main' into tool-retrieval-2
WonderPG 9597d75
Finish
WonderPG 96410cc
Add env var
WonderPG 648c266
Final fixes
WonderPG c855e19
Update github action
WonderPG e837864
Update ruff linting
WonderPG bd34212
Fix tests
WonderPG acf1d05
Merge main
WonderPG dfcc247
Default to 10 tools min
WonderPG b10e36a
Ignore dependency if tool list is empty
WonderPG adf0208
Small fix
WonderPG 720b663
Split logic into helper + dependency
WonderPG e8acf49
Merge branch 'main' into tool-retrieval-2
WonderPG 5ffd05e
Create tool_selection table and store tool selection in db
WonderPG 86e7411
Fix alembic types
WonderPG d634f2c
Turn docstring into numpy doc
WonderPG 5f2edae
Merge branch 'main' into tool-retrieval-2
WonderPG f54379f
Turn logger into debug
WonderPG 92881d8
Remove duplicated if
WonderPG 2fda127
Return if len(tool_list) == min_selection
WonderPG 06a18a2
Unify if statements
WonderPG 4e25024
Force min number of tools to be positive
WonderPG e0e36ca
revert
WonderPG File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,14 +5,15 @@ | |
from datetime import datetime, timezone | ||
from functools import cache | ||
from pathlib import Path | ||
from typing import Annotated, Any, AsyncIterator | ||
from typing import Annotated, Any, AsyncIterator, Literal | ||
|
||
import boto3 | ||
from fastapi import Depends, HTTPException, Request | ||
from fastapi.security import HTTPBearer | ||
from httpx import AsyncClient, HTTPStatusError, get | ||
from obp_accounting_sdk import AsyncAccountingSessionFactory | ||
from openai import AsyncOpenAI | ||
from pydantic import BaseModel, Field | ||
from redis import asyncio as aioredis | ||
from semantic_router.routers import SemanticRouter | ||
from sqlalchemy import select | ||
|
@@ -102,6 +103,7 @@ | |
WebSearchTool, | ||
) | ||
from neuroagent.tools.base_tool import BaseTool | ||
from neuroagent.utils import messages_to_openai_content | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -244,6 +246,33 @@ async def get_user_info( | |
raise HTTPException(status_code=404, detail="User info url not provided.") | ||
|
||
|
||
async def get_thread( | ||
user_info: Annotated[UserInfo, Depends(get_user_info)], | ||
thread_id: str, | ||
session: Annotated[AsyncSession, Depends(get_session)], | ||
) -> Threads: | ||
"""Check if the current thread / user matches.""" | ||
thread_result = await session.execute( | ||
select(Threads).where( | ||
Threads.user_id == user_info.sub, Threads.thread_id == thread_id | ||
) | ||
) | ||
thread = thread_result.scalars().one_or_none() | ||
if not thread: | ||
raise HTTPException( | ||
status_code=404, | ||
detail={ | ||
"detail": "Thread not found.", | ||
}, | ||
) | ||
validate_project( | ||
groups=user_info.groups, | ||
virtual_lab_id=thread.vlab_id, | ||
project_id=thread.project_id, | ||
) | ||
return thread | ||
|
||
|
||
def get_mcp_client(request: Request) -> MCPClient | None: | ||
"""Get the MCP client from the app state.""" | ||
if request.app.state.mcp_client is None: | ||
|
@@ -415,6 +444,72 @@ async def get_selected_tools( | |
return selected_tools | ||
|
||
|
||
async def filtered_tools( | ||
request: Request, | ||
thread: Annotated[Threads, Depends(get_thread)], | ||
tool_list: Annotated[list[type[BaseTool]], Depends(get_selected_tools)], | ||
openai_client: Annotated[AsyncOpenAI, Depends(get_openai_client)], | ||
settings: Annotated[Settings, Depends(get_settings)], | ||
) -> list[type[BaseTool]]: | ||
"""Based on the current conversation, select relevant tools.""" | ||
if request.method == "GET": | ||
return tool_list | ||
|
||
# Awaiting here makes downstream calls already loaded so no performance issue | ||
messages = await thread.awaitable_attrs.messages | ||
openai_messages = await messages_to_openai_content(messages) | ||
|
||
# Remove the content of tool responses to save tokens | ||
for message in openai_messages: | ||
if message["role"] == "tool": | ||
message["content"] = "..." | ||
|
||
body = await request.json() | ||
openai_messages.append({"role": "user", "content": body["content"]}) | ||
|
||
system_prompt = f"""TASK: Filter tools for AI agent based on conversation relevance. | ||
|
||
AVAILABLE TOOLS: | ||
{chr(10).join(f"{tool.name}: {tool.description}" for tool in tool_list)} | ||
|
||
INSTRUCTIONS: | ||
1. Analyze the conversation to identify required capabilities | ||
2. Select at least 5 of the most relevant tools by name only | ||
3. BIAS TOWARD INCLUSION: If uncertain about a tool's relevance, include it - better to provide too many tools than too few | ||
4. Only exclude tools that are clearly irrelevant to the conversation | ||
5. Output format: comma-separated list of tool names | ||
6. Do not respond to user queries - only filter tools | ||
|
||
OUTPUT: [tool_name1, tool_name2, ...]""" | ||
|
||
tool_names = [tool.name for tool in tool_list] | ||
TOOL_NAMES_LITERAL = Literal[*tool_names] if len(tool_names) > 0 else str | ||
|
||
class ToolSelection(BaseModel): | ||
"""Data class for tool selection by an LLM.""" | ||
|
||
selected_tools: list[TOOL_NAMES_LITERAL] = Field( # type: ignore | ||
min_length=min(len(tool_names), settings.tools.min_tool_selection), | ||
description=f"List of selected tool names, minimum {min(len(tool_names), settings.tools.min_tool_selection)} items. Must contain all of the tools relevant to the conversation.", | ||
) | ||
|
||
# Rest of your code remains the same | ||
response = await openai_client.beta.chat.completions.parse( | ||
messages=[{"role": "system", "content": system_prompt}, *openai_messages], # type: ignore | ||
model="gpt-4o-mini", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could be turned into an env var but we have enough already. Please let me know if you would prefer an env var. |
||
response_format=ToolSelection, | ||
) | ||
|
||
if response.choices[0].message.parsed: | ||
selected_tools = response.choices[0].message.parsed.selected_tools | ||
logger.info( | ||
f"QUERY: {body['content']}, #TOOLS: {len(selected_tools)}, SELECTED TOOLS: {selected_tools}" | ||
) | ||
WonderPG marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return [tool for tool in tool_list if tool.name in selected_tools] | ||
else: | ||
return [] | ||
jankrepl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@cache | ||
def get_rules_dir() -> Path: | ||
"""Get the path to the rules directory.""" | ||
|
@@ -476,7 +571,7 @@ def get_system_prompt(rules_dir: Annotated[Path, Depends(get_rules_dir)]) -> str | |
|
||
|
||
def get_starting_agent( | ||
tool_list: Annotated[list[type[BaseTool]], Depends(get_selected_tools)], | ||
tool_list: Annotated[list[type[BaseTool]], Depends(filtered_tools)], | ||
system_prompt: Annotated[str, Depends(get_system_prompt)], | ||
) -> Agent: | ||
"""Get the starting agent.""" | ||
|
@@ -488,33 +583,6 @@ def get_starting_agent( | |
return agent | ||
|
||
|
||
async def get_thread( | ||
user_info: Annotated[UserInfo, Depends(get_user_info)], | ||
thread_id: str, | ||
session: Annotated[AsyncSession, Depends(get_session)], | ||
) -> Threads: | ||
"""Check if the current thread / user matches.""" | ||
thread_result = await session.execute( | ||
select(Threads).where( | ||
Threads.user_id == user_info.sub, Threads.thread_id == thread_id | ||
) | ||
) | ||
thread = thread_result.scalars().one_or_none() | ||
if not thread: | ||
raise HTTPException( | ||
status_code=404, | ||
detail={ | ||
"detail": "Thread not found.", | ||
}, | ||
) | ||
validate_project( | ||
groups=user_info.groups, | ||
virtual_lab_id=thread.vlab_id, | ||
project_id=thread.project_id, | ||
) | ||
return thread | ||
|
||
|
||
def get_semantic_routes(request: Request) -> SemanticRouter | None: | ||
"""Get the semantic route object for basic guardrails.""" | ||
return request.app.state.semantic_router | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.