diff --git a/CHANGELOG.md b/CHANGELOG.md index ab9949bd..e2a1b774 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Python interpreter tool. - Auto alembic upgrade. - New utterances in false positive semantic route. +- Tool pre-selection. - Ephys tool. ### Changed diff --git a/backend/.env.example b/backend/.env.example index f3fac9c3..9152aa93 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -22,6 +22,7 @@ NEUROAGENT_TOOLS__ENTITYCORE__URL= NEUROAGENT_TOOLS__BLUENAAS_URL= NEUROAGENT_TOOLS__THUMBNAIL_GENERATION__URL= NEUROAGENT_TOOLS__WHITELISTED_TOOL_REGEX= +NEUROAGENT_TOOLS__MIN_TOOL_SELECTION= NEUROAGENT_LLM__SUGGESTION_MODEL= diff --git a/backend/alembic/versions/dde4f8453a14_add_tool_selection_table.py b/backend/alembic/versions/dde4f8453a14_add_tool_selection_table.py new file mode 100644 index 00000000..55f2f968 --- /dev/null +++ b/backend/alembic/versions/dde4f8453a14_add_tool_selection_table.py @@ -0,0 +1,41 @@ +"""Add tool_selection table + +Revision ID: dde4f8453a14 +Revises: 529e44b33a67 +Create Date: 2025-07-09 16:11:54.496827 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "dde4f8453a14" +down_revision: Union[str, None] = "529e44b33a67" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "tool_selection", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("selected_tools", sa.String(), nullable=False), + sa.Column("message_id", sa.UUID(), nullable=False), + sa.ForeignKeyConstraint( + ["message_id"], + ["messages.message_id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("tool_selection") + # ### end Alembic commands ### diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 11b67588..9c7306c5 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -69,7 +69,7 @@ preview = true [tool.ruff] line-length = 88 -target-version = "py310" +target-version = "py311" exclude = ["src/neuroagent/tools/autogenerated_types"] [tool.ruff.lint] diff --git a/backend/src/neuroagent/app/app_utils.py b/backend/src/neuroagent/app/app_utils.py index d09261aa..46360ec9 100644 --- a/backend/src/neuroagent/app/app_utils.py +++ b/backend/src/neuroagent/app/app_utils.py @@ -5,10 +5,11 @@ import re import uuid from pathlib import Path -from typing import Any, Sequence +from typing import Any, Literal, Sequence import yaml from fastapi import HTTPException +from openai import AsyncOpenAI from pydantic import BaseModel, ConfigDict, Field from redis import asyncio as aioredis from semantic_router import Route @@ -34,6 +35,7 @@ ToolCallVercel, ) from neuroagent.schemas import EmbeddedBrainRegions +from neuroagent.tools.base_tool import BaseTool logger = logging.getLogger(__name__) @@ -469,3 +471,90 @@ def parse_redis_data( remaining=remaining, reset_in=reset_in, ) + + +async def filter_tools_by_conversation( + openai_messages: list[dict[str, str]], + tool_list: list[type[BaseTool]], + user_content: str, + openai_client: AsyncOpenAI, + min_tool_selection: int, +) -> list[type[BaseTool]]: + """ + Filter tools based on conversation relevance. 
+ + Parameters + ---------- + openai_messages: + List of OpenAI formatted messages + tool_list: + List of available tools + user_content: + Current user message content + openai_client: + OpenAI client instance + min_tool_selection: + Minimum numbers of tools the LLM should select + + Returns + ------- + List of filtered tools relevant to the conversation + """ + if len(tool_list) <= min_tool_selection: + return tool_list + + # Remove the content of tool responses to save tokens + for message in openai_messages: + if message["role"] == "tool": + message["content"] = "..." + + # Add the current user message + openai_messages.append({"role": "user", "content": user_content}) + + system_prompt = f"""TASK: Filter tools for AI agent based on conversation relevance. + +INSTRUCTIONS: +1. Analyze the conversation to identify required capabilities +2. Select at least {min_tool_selection} of the most relevant tools by name only +3. BIAS TOWARD INCLUSION: If uncertain about a tool's relevance, include it - better to provide too many tools than too few +4. Only exclude tools that are clearly irrelevant to the conversation +5. Output format: comma-separated list of tool names +6. Do not respond to user queries - only filter tools +7. Each tool must be selected only once. + +OUTPUT: [tool_name1, tool_name2, ...] + +AVAILABLE TOOLS: +{chr(10).join(f"{tool.name}: {tool.description}" for tool in tool_list)} +""" + + tool_names = [tool.name for tool in tool_list] + TOOL_NAMES_LITERAL = Literal[*tool_names] # type: ignore + + class ToolSelection(BaseModel): + """Data class for tool selection by an LLM.""" + + selected_tools: list[TOOL_NAMES_LITERAL] = Field( + min_length=min_tool_selection, + description=f"List of selected tool names, minimum {min_tool_selection} items. 
Must contain all of the tools relevant to the conversation.", + ) + + try: + response = await openai_client.beta.chat.completions.parse( + messages=[{"role": "system", "content": system_prompt}, *openai_messages], # type: ignore + model="gpt-4o-mini", + response_format=ToolSelection, + ) + + if response.choices[0].message.parsed: + selected_tools = response.choices[0].message.parsed.selected_tools + logger.debug( + f"QUERY: {user_content}, #TOOLS: {len(selected_tools)}, SELECTED TOOLS: {selected_tools}" + ) + return [tool for tool in tool_list if tool.name in selected_tools] + else: + logger.warning("No parsed response from OpenAI, returning empty list") + return [] + except Exception as e: + logger.error(f"Error filtering tools: {e}") + return [] diff --git a/backend/src/neuroagent/app/config.py b/backend/src/neuroagent/app/config.py index c826bb64..8f45bac4 100644 --- a/backend/src/neuroagent/app/config.py +++ b/backend/src/neuroagent/app/config.py @@ -7,7 +7,7 @@ from typing import Any, Literal from dotenv import dotenv_values -from pydantic import BaseModel, ConfigDict, SecretStr, model_validator +from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict logger = logging.getLogger(__name__) @@ -117,6 +117,7 @@ class SettingsTools(BaseModel): entitycore: SettingsEntityCore = SettingsEntityCore() web_search: SettingsWebSearch = SettingsWebSearch() thumbnail_generation: SettingsThumbnailGeneration = SettingsThumbnailGeneration() + min_tool_selection: int = Field(default=10, ge=0) whitelisted_tool_regex: str | None = None model_config = ConfigDict(frozen=True) diff --git a/backend/src/neuroagent/app/database/sql_schemas.py b/backend/src/neuroagent/app/database/sql_schemas.py index f4f54db2..9b4ca13c 100644 --- a/backend/src/neuroagent/app/database/sql_schemas.py +++ b/backend/src/neuroagent/app/database/sql_schemas.py @@ -73,6 +73,7 @@ class Messages(Base): tool_calls: 
Mapped[list["ToolCalls"]] = relationship( "ToolCalls", back_populates="message", cascade="all, delete-orphan" ) + selected_tools: Mapped[list["ToolSelection"]] = relationship("ToolSelection") class ToolCalls(Base): @@ -88,3 +89,16 @@ class ToolCalls(Base): UUID, ForeignKey("messages.message_id") ) message: Mapped[Messages] = relationship("Messages", back_populates="tool_calls") + + +class ToolSelection(Base): + """SQL table used for storing the tool selected for a query.""" + + __tablename__ = "tool_selection" + id: Mapped[uuid.UUID] = mapped_column( + UUID, primary_key=True, default=lambda: uuid.uuid4() + ) + selected_tools: Mapped[str] = mapped_column(String, nullable=False) + message_id: Mapped[uuid.UUID] = mapped_column( + UUID, ForeignKey("messages.message_id") + ) diff --git a/backend/src/neuroagent/app/dependencies.py b/backend/src/neuroagent/app/dependencies.py index 36da8844..3caa00da 100644 --- a/backend/src/neuroagent/app/dependencies.py +++ b/backend/src/neuroagent/app/dependencies.py @@ -20,7 +20,7 @@ from starlette.status import HTTP_401_UNAUTHORIZED from neuroagent.agent_routine import AgentsRoutine -from neuroagent.app.app_utils import validate_project +from neuroagent.app.app_utils import filter_tools_by_conversation, validate_project from neuroagent.app.config import Settings from neuroagent.app.database.sql_schemas import Threads from neuroagent.app.schemas import OpenRouterModelResponse, UserInfo @@ -102,6 +102,7 @@ WebSearchTool, ) from neuroagent.tools.base_tool import BaseTool +from neuroagent.utils import messages_to_openai_content logger = logging.getLogger(__name__) @@ -244,6 +245,33 @@ async def get_user_info( raise HTTPException(status_code=404, detail="User info url not provided.") +async def get_thread( + user_info: Annotated[UserInfo, Depends(get_user_info)], + thread_id: str, + session: Annotated[AsyncSession, Depends(get_session)], +) -> Threads: + """Check if the current thread / user matches.""" + thread_result = await
session.execute( + select(Threads).where( + Threads.user_id == user_info.sub, Threads.thread_id == thread_id + ) + ) + thread = thread_result.scalars().one_or_none() + if not thread: + raise HTTPException( + status_code=404, + detail={ + "detail": "Thread not found.", + }, + ) + validate_project( + groups=user_info.groups, + virtual_lab_id=thread.vlab_id, + project_id=thread.project_id, + ) + return thread + + def get_mcp_client(request: Request) -> MCPClient | None: """Get the MCP client from the app state.""" if request.app.state.mcp_client is None: @@ -415,6 +443,36 @@ async def get_selected_tools( return selected_tools +async def filtered_tools( + request: Request, + thread: Annotated[Threads, Depends(get_thread)], + tool_list: Annotated[list[type[BaseTool]], Depends(get_selected_tools)], + openai_client: Annotated[AsyncOpenAI, Depends(get_openai_client)], + settings: Annotated[Settings, Depends(get_settings)], +) -> list[type[BaseTool]]: + """Based on the current conversation, select relevant tools.""" + if request.method == "GET": + return tool_list + + # Awaiting here makes downstream calls already loaded so no performance issue + messages = await thread.awaitable_attrs.messages + if not tool_list: + return [] + + openai_messages = await messages_to_openai_content(messages) + + body = await request.json() + user_content = body["content"] + + return await filter_tools_by_conversation( + openai_messages=openai_messages, + tool_list=tool_list, + user_content=user_content, + openai_client=openai_client, + min_tool_selection=settings.tools.min_tool_selection, + ) + + @cache def get_rules_dir() -> Path: """Get the path to the rules directory.""" @@ -476,7 +534,7 @@ def get_system_prompt(rules_dir: Annotated[Path, Depends(get_rules_dir)]) -> str def get_starting_agent( - tool_list: Annotated[list[type[BaseTool]], Depends(get_selected_tools)], + tool_list: Annotated[list[type[BaseTool]], Depends(filtered_tools)], system_prompt: Annotated[str, 
Depends(get_system_prompt)], settings: Annotated[Settings, Depends(get_settings)], ) -> Agent: @@ -490,33 +548,6 @@ def get_starting_agent( return agent -async def get_thread( - user_info: Annotated[UserInfo, Depends(get_user_info)], - thread_id: str, - session: Annotated[AsyncSession, Depends(get_session)], -) -> Threads: - """Check if the current thread / user matches.""" - thread_result = await session.execute( - select(Threads).where( - Threads.user_id == user_info.sub, Threads.thread_id == thread_id - ) - ) - thread = thread_result.scalars().one_or_none() - if not thread: - raise HTTPException( - status_code=404, - detail={ - "detail": "Thread not found.", - }, - ) - validate_project( - groups=user_info.groups, - virtual_lab_id=thread.vlab_id, - project_id=thread.project_id, - ) - return thread - - def get_semantic_routes(request: Request) -> SemanticRouter | None: """Get the semantic route object for basic guardrails.""" return request.app.state.semantic_router diff --git a/backend/src/neuroagent/app/routers/qa.py b/backend/src/neuroagent/app/routers/qa.py index e0d553c6..29aceea2 100644 --- a/backend/src/neuroagent/app/routers/qa.py +++ b/backend/src/neuroagent/app/routers/qa.py @@ -32,11 +32,7 @@ validate_project, ) from neuroagent.app.config import Settings -from neuroagent.app.database.sql_schemas import ( - Entity, - Messages, - Threads, -) +from neuroagent.app.database.sql_schemas import Entity, Messages, Threads, ToolSelection from neuroagent.app.dependencies import ( get_accounting_session_factory, get_agents_routine, @@ -304,15 +300,8 @@ async def stream_chat_agent( agent.model = agent.model.removeprefix("openai/") agents_routine.client = openai_client - messages: list[Messages] = await thread.awaitable_attrs.messages - # Since the session is not reinstantiated in stream.py - # we need to lazy load the tool_calls in advance since in - # any case they will be needed to convert the db schema - # to OpenAI messages - for msg in messages: - if msg.entity 
== Entity.AI_TOOL: - # This awaits the lazy loading, ensuring tool_calls is populated now. - await msg.awaitable_attrs.tool_calls + # No need to await since it has been awaited in tool filtering dependency + messages: list[Messages] = thread.messages if ( not messages @@ -326,6 +315,9 @@ async def stream_chat_agent( content=json.dumps({"role": "user", "content": user_request.content}), is_complete=True, model=None, + selected_tools=[ + ToolSelection(selected_tools=tool.name) for tool in agent.tools + ], ) ) diff --git a/backend/src/neuroagent/rules/ai-message-formatting.mdc b/backend/src/neuroagent/rules/ai-message-formatting.mdc index 9336cc79..95da6a2f 100644 --- a/backend/src/neuroagent/rules/ai-message-formatting.mdc +++ b/backend/src/neuroagent/rules/ai-message-formatting.mdc @@ -3,12 +3,11 @@ description: AI-message formatting guidelines --- # Message formatting guidelines -## Structure and Organization -- Use hierarchical headers (##, ###) to organize information by categories and subcategories -# AI Message Formatting Guidelines +- Always use rich markdown formatting. 
## Content Structure Guidelines +- Use hierarchical headers (##, ###) to organize information by categories and subcategories - Create detailed subsections for each specific area or entity - Group related items under appropriate subheaders - Use descriptive section titles that specify the exact areas or categories diff --git a/backend/src/neuroagent/tools/autogenerated_types/entitycore.py b/backend/src/neuroagent/tools/autogenerated_types/entitycore.py index d23798d5..ec4bea41 100644 --- a/backend/src/neuroagent/tools/autogenerated_types/entitycore.py +++ b/backend/src/neuroagent/tools/autogenerated_types/entitycore.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: https://staging.openbraininstitute.org/api/entitycore/openapi.json -# timestamp: 2025-07-09T08:29:12+00:00 +# timestamp: 2025-07-09T08:41:55+00:00 from __future__ import annotations diff --git a/backend/src/neuroagent/utils.py b/backend/src/neuroagent/utils.py index c2246a30..da004730 100644 --- a/backend/src/neuroagent/utils.py +++ b/backend/src/neuroagent/utils.py @@ -44,31 +44,7 @@ async def messages_to_openai_content( messages = [] if db_messages: for msg in db_messages: - if msg.content and msg.entity == Entity.AI_TOOL: - # Load the base content - content = json.loads(msg.content) - - # Get the associated tool calls - tool_calls = msg.tool_calls - - # Format it back into the json OpenAI expects - tool_calls_content = [ - { - "function": { - "arguments": tool_call.arguments, - "name": tool_call.name, - }, - "id": tool_call.tool_call_id, - "type": "function", - } - for tool_call in tool_calls - ] - - # Assign it back to the main content - content["tool_calls"] = tool_calls_content - messages.append(content) - else: - messages.append(json.loads(msg.content)) + messages.append(json.loads(msg.content)) return messages diff --git a/backend/tests/app/routers/test_qa.py b/backend/tests/app/routers/test_qa.py index 612ebcd4..c58d0d82 100644 --- a/backend/tests/app/routers/test_qa.py +++ 
b/backend/tests/app/routers/test_qa.py @@ -1,12 +1,17 @@ +from typing import Annotated from unittest.mock import Mock import pytest +from fastapi import Depends from neuroagent.app.config import Settings +from neuroagent.app.database.sql_schemas import Threads from neuroagent.app.dependencies import ( + filtered_tools, get_agents_routine, get_openai_client, get_settings, + get_thread, ) from neuroagent.app.main import app from neuroagent.app.routers import qa @@ -188,8 +193,17 @@ def test_chat_streamed( rate_limiter={"disabled": True}, ) app.dependency_overrides[get_settings] = lambda: test_settings - agent_routine = Mock() - app.dependency_overrides[get_agents_routine] = lambda: agent_routine + app.dependency_overrides[get_agents_routine] = lambda: Mock() + + async def mock_dependency_with_lazy_loading( + thread: Annotated[Threads, Depends(get_thread)], + ): + # Thread is injected automatically + # Perform the lazy loading + await thread.awaitable_attrs.messages + return [] + + app.dependency_overrides[filtered_tools] = mock_dependency_with_lazy_loading expected_tokens = ( b"Calling tool : resolve_entities_tool with arguments : {brain_region:" diff --git a/backend/tests/app/test_app_utils.py b/backend/tests/app/test_app_utils.py index a653bd7f..1910dd8f 100644 --- a/backend/tests/app/test_app_utils.py +++ b/backend/tests/app/test_app_utils.py @@ -2,13 +2,16 @@ import json from datetime import datetime +from typing import Literal from unittest.mock import AsyncMock, patch from uuid import UUID import pytest from fastapi.exceptions import HTTPException +from pydantic import BaseModel, Field from neuroagent.app.app_utils import ( + filter_tools_by_conversation, format_messages_output, format_messages_vercel, parse_redis_data, @@ -31,6 +34,7 @@ ToolCallVercel, UserInfo, ) +from tests.mock_client import MockOpenAIClient, create_mock_response @pytest.mark.asyncio @@ -565,3 +569,54 @@ def test_various_limit_values(sample_redis_info, limit_value): limit=limit_value, 
remaining=expected_remaining, reset_in=123 ) assert result == expected + + +@pytest.mark.asyncio +async def test_filter_tools_empty_tool_list(): + """Test that empty tool list returns empty list""" + result = await filter_tools_by_conversation( + openai_messages=[], + tool_list=[], + user_content="test", + openai_client=AsyncMock(), + min_tool_selection=1, + ) + assert result == [] + + +@pytest.mark.asyncio +async def test_filter_tools_successful_selection(get_weather_tool, agent_handoff_tool): + """Test successful tool filtering""" + # Mock OpenAI response + mock_openai_client = MockOpenAIClient() + + class ToolSelection(BaseModel): + """Data class for tool selection by an LLM.""" + + selected_tools: list[Literal["agent_handoff_tool", "get_weather_tool"]] = Field( + min_length=1, + description="List of selected tool names, minimum 1 items. Must contain all of the tools relevant to the conversation.", + ) + + mock_openai_client.set_response( + create_mock_response( + {"role": "assistant", "content": ""}, + structured_output_class=ToolSelection( + selected_tools=["agent_handoff_tool"] + ), + ) + ) + + result = await filter_tools_by_conversation( + openai_messages=[ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ], + tool_list=[get_weather_tool, agent_handoff_tool], + user_content="I need help with Agent handoff", + openai_client=mock_openai_client, + min_tool_selection=1, + ) + + assert len(result) == 1 + assert result[0].name == "agent_handoff_tool"