Skip to content

Tool retrieval 2 #417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Python interpreter tool.
- Auto alembic upgrade.
- New utterances in false positive semantic route.
- Tool pre-selection.
- Ephys tool.

## Changed
Expand Down
1 change: 1 addition & 0 deletions backend/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ NEUROAGENT_TOOLS__ENTITYCORE__URL=
NEUROAGENT_TOOLS__BLUENAAS_URL=
NEUROAGENT_TOOLS__THUMBNAIL_GENERATION__URL=
NEUROAGENT_TOOLS__WHITELISTED_TOOL_REGEX=
NEUROAGENT_TOOLS__MIN_TOOL_SELECTION=


NEUROAGENT_LLM__SUGGESTION_MODEL=
Expand Down
41 changes: 41 additions & 0 deletions backend/alembic/versions/825dd56cabf4_add_tool_selection_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Add tool_selection table

Revision ID: 825dd56cabf4
Revises: 529e44b33a67
Create Date: 2025-07-09 16:05:31.480188

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "825dd56cabf4"
down_revision: Union[str, None] = "529e44b33a67"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the ``tool_selection`` table.

    One row per tool selected for a given user message; ``message_id``
    references the message that triggered the selection.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "tool_selection",
        # UUID to match the ORM model (ToolSelection.id is Mapped[uuid.UUID]
        # with a UUID column type); the autogenerated sa.String() was
        # inconsistent with the mapping and with message_id below.
        sa.Column("id", sa.UUID(), nullable=False),
        sa.Column("selected_tools", sa.String(), nullable=False),
        sa.Column("message_id", sa.UUID(), nullable=False),
        sa.ForeignKeyConstraint(
            ["message_id"],
            ["messages.message_id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop the ``tool_selection`` table, reversing :func:`upgrade`."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name = "tool_selection"
    op.drop_table(table_name)
    # ### end Alembic commands ###
2 changes: 1 addition & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ preview = true

[tool.ruff]
line-length = 88
target-version = "py310"
target-version = "py311"
exclude = ["src/neuroagent/tools/autogenerated_types"]

[tool.ruff.lint]
Expand Down
88 changes: 87 additions & 1 deletion backend/src/neuroagent/app/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import re
import uuid
from pathlib import Path
from typing import Any, Sequence
from typing import Any, Literal, Sequence

import yaml
from fastapi import HTTPException
from openai import AsyncOpenAI
from pydantic import BaseModel, ConfigDict, Field
from redis import asyncio as aioredis
from semantic_router import Route
Expand All @@ -34,6 +35,7 @@
ToolCallVercel,
)
from neuroagent.schemas import EmbeddedBrainRegions
from neuroagent.tools.base_tool import BaseTool

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -469,3 +471,87 @@ def parse_redis_data(
remaining=remaining,
reset_in=reset_in,
)


async def filter_tools_by_conversation(
openai_messages: list[dict[str, str]],
tool_list: list[type[BaseTool]],
user_content: str,
openai_client: AsyncOpenAI,
min_tool_selection: int,
) -> list[type[BaseTool]]:
"""
Filter tools based on conversation relevance.

Args:
openai_messages: List of OpenAI formatted messages
tool_list: List of available tools
user_content: Current user message content
openai_client: OpenAI client instance
settings: Application settings

Returns
-------
List of filtered tools relevant to the conversation
"""
if not tool_list:
return []

if len(tool_list) < min_tool_selection:
return tool_list

# Remove the content of tool responses to save tokens
for message in openai_messages:
if message["role"] == "tool":
message["content"] = "..."

# Add the current user message
openai_messages.append({"role": "user", "content": user_content})

system_prompt = f"""TASK: Filter tools for AI agent based on conversation relevance.

INSTRUCTIONS:
1. Analyze the conversation to identify required capabilities
2. Select at least {min_tool_selection} of the most relevant tools by name only
3. BIAS TOWARD INCLUSION: If uncertain about a tool's relevance, include it - better to provide too many tools than too few
4. Only exclude tools that are clearly irrelevant to the conversation
5. Output format: comma-separated list of tool names
6. Do not respond to user queries - only filter tools
7. Each tool must be selected only once.

OUTPUT: [tool_name1, tool_name2, ...]

AVAILABLE TOOLS:
{chr(10).join(f"{tool.name}: {tool.description}" for tool in tool_list)}
"""

tool_names = [tool.name for tool in tool_list]
TOOL_NAMES_LITERAL = Literal[*tool_names] # type: ignore

class ToolSelection(BaseModel):
"""Data class for tool selection by an LLM."""

selected_tools: list[TOOL_NAMES_LITERAL] = Field(
min_length=min_tool_selection,
description=f"List of selected tool names, minimum {min_tool_selection} items. Must contain all of the tools relevant to the conversation.",
)

try:
response = await openai_client.beta.chat.completions.parse(
messages=[{"role": "system", "content": system_prompt}, *openai_messages], # type: ignore
model="gpt-4o-mini",
response_format=ToolSelection,
)

if response.choices[0].message.parsed:
selected_tools = response.choices[0].message.parsed.selected_tools
logger.info(
f"QUERY: {user_content}, #TOOLS: {len(selected_tools)}, SELECTED TOOLS: {selected_tools}"
)
return [tool for tool in tool_list if tool.name in selected_tools]
else:
logger.warning("No parsed response from OpenAI, returning empty list")
return []
except Exception as e:
logger.error(f"Error filtering tools: {e}")
return []
1 change: 1 addition & 0 deletions backend/src/neuroagent/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class SettingsTools(BaseModel):
entitycore: SettingsEntityCore = SettingsEntityCore()
web_search: SettingsWebSearch = SettingsWebSearch()
thumbnail_generation: SettingsThumbnailGeneration = SettingsThumbnailGeneration()
min_tool_selection: int = 10
whitelisted_tool_regex: str | None = None

model_config = ConfigDict(frozen=True)
Expand Down
14 changes: 14 additions & 0 deletions backend/src/neuroagent/app/database/sql_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class Messages(Base):
tool_calls: Mapped[list["ToolCalls"]] = relationship(
"ToolCalls", back_populates="message", cascade="all, delete-orphan"
)
selected_tools: Mapped[list["ToolCalls"]] = relationship("ToolSelection")


class ToolCalls(Base):
Expand All @@ -88,3 +89,16 @@ class ToolCalls(Base):
UUID, ForeignKey("messages.message_id")
)
message: Mapped[Messages] = relationship("Messages", back_populates="tool_calls")


class ToolSelection(Base):
    """SQL table used for storing the tool selected for a query."""

    __tablename__ = "tool_selection"
    # Surrogate primary key, generated client-side (one row per selected tool).
    id: Mapped[uuid.UUID] = mapped_column(
        UUID, primary_key=True, default=lambda: uuid.uuid4()
    )
    # Name of a single selected tool, stored as a plain string.
    # NOTE(review): confirm the migration creates this table with a UUID
    # column type for `id` to match this mapping.
    selected_tools: Mapped[str] = mapped_column(String, nullable=False)
    # FK to the user message whose conversation triggered the selection.
    message_id: Mapped[uuid.UUID] = mapped_column(
        UUID, ForeignKey("messages.message_id")
    )
89 changes: 60 additions & 29 deletions backend/src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from starlette.status import HTTP_401_UNAUTHORIZED

from neuroagent.agent_routine import AgentsRoutine
from neuroagent.app.app_utils import validate_project
from neuroagent.app.app_utils import filter_tools_by_conversation, validate_project
from neuroagent.app.config import Settings
from neuroagent.app.database.sql_schemas import Threads
from neuroagent.app.schemas import OpenRouterModelResponse, UserInfo
Expand Down Expand Up @@ -102,6 +102,7 @@
WebSearchTool,
)
from neuroagent.tools.base_tool import BaseTool
from neuroagent.utils import messages_to_openai_content

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -244,6 +245,33 @@ async def get_user_info(
raise HTTPException(status_code=404, detail="User info url not provided.")


async def get_thread(
    user_info: Annotated[UserInfo, Depends(get_user_info)],
    thread_id: str,
    session: Annotated[AsyncSession, Depends(get_session)],
) -> Threads:
    """Fetch the thread and verify it belongs to the requesting user."""
    query = select(Threads).where(
        Threads.user_id == user_info.sub, Threads.thread_id == thread_id
    )
    result = await session.execute(query)
    thread = result.scalars().one_or_none()

    if thread is None:
        raise HTTPException(
            status_code=404,
            detail={
                "detail": "Thread not found.",
            },
        )

    # Ensure the caller is a member of the vlab/project owning this thread.
    validate_project(
        groups=user_info.groups,
        virtual_lab_id=thread.vlab_id,
        project_id=thread.project_id,
    )
    return thread


def get_mcp_client(request: Request) -> MCPClient | None:
"""Get the MCP client from the app state."""
if request.app.state.mcp_client is None:
Expand Down Expand Up @@ -415,6 +443,36 @@ async def get_selected_tools(
return selected_tools


async def filtered_tools(
    request: Request,
    thread: Annotated[Threads, Depends(get_thread)],
    tool_list: Annotated[list[type[BaseTool]], Depends(get_selected_tools)],
    openai_client: Annotated[AsyncOpenAI, Depends(get_openai_client)],
    settings: Annotated[Settings, Depends(get_settings)],
) -> list[type[BaseTool]]:
    """Based on the current conversation, select relevant tools.

    GET requests bypass filtering and return the full tool list; POST
    requests are narrowed via an LLM call using the thread's history plus
    the incoming user message.
    """
    if request.method == "GET":
        return tool_list

    # Cheap guard first: with no tools there is nothing to filter, and we
    # avoid loading the conversation from the database entirely.
    if not tool_list:
        return []

    # Awaiting here makes downstream calls already loaded so no performance issue
    messages = await thread.awaitable_attrs.messages

    openai_messages = await messages_to_openai_content(messages)

    body = await request.json()
    # NOTE(review): raises KeyError if a POST body lacks "content" — confirm
    # every route using this dependency always sends it.
    user_content = body["content"]

    return await filter_tools_by_conversation(
        openai_messages=openai_messages,
        tool_list=tool_list,
        user_content=user_content,
        openai_client=openai_client,
        min_tool_selection=settings.tools.min_tool_selection,
    )


@cache
def get_rules_dir() -> Path:
"""Get the path to the rules directory."""
Expand Down Expand Up @@ -476,7 +534,7 @@ def get_system_prompt(rules_dir: Annotated[Path, Depends(get_rules_dir)]) -> str


def get_starting_agent(
tool_list: Annotated[list[type[BaseTool]], Depends(get_selected_tools)],
tool_list: Annotated[list[type[BaseTool]], Depends(filtered_tools)],
system_prompt: Annotated[str, Depends(get_system_prompt)],
) -> Agent:
"""Get the starting agent."""
Expand All @@ -488,33 +546,6 @@ def get_starting_agent(
return agent


async def get_thread(
user_info: Annotated[UserInfo, Depends(get_user_info)],
thread_id: str,
session: Annotated[AsyncSession, Depends(get_session)],
) -> Threads:
"""Check if the current thread / user matches."""
thread_result = await session.execute(
select(Threads).where(
Threads.user_id == user_info.sub, Threads.thread_id == thread_id
)
)
thread = thread_result.scalars().one_or_none()
if not thread:
raise HTTPException(
status_code=404,
detail={
"detail": "Thread not found.",
},
)
validate_project(
groups=user_info.groups,
virtual_lab_id=thread.vlab_id,
project_id=thread.project_id,
)
return thread


def get_semantic_routes(request: Request) -> SemanticRouter | None:
"""Get the semantic route object for basic guardrails."""
return request.app.state.semantic_router
Expand Down
20 changes: 6 additions & 14 deletions backend/src/neuroagent/app/routers/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,7 @@
validate_project,
)
from neuroagent.app.config import Settings
from neuroagent.app.database.sql_schemas import (
Entity,
Messages,
Threads,
)
from neuroagent.app.database.sql_schemas import Entity, Messages, Threads, ToolSelection
from neuroagent.app.dependencies import (
get_accounting_session_factory,
get_agents_routine,
Expand Down Expand Up @@ -304,15 +300,8 @@ async def stream_chat_agent(
agent.model = agent.model.removeprefix("openai/")
agents_routine.client = openai_client

messages: list[Messages] = await thread.awaitable_attrs.messages
# Since the session is not reinstantiated in stream.py
# we need to lazy load the tool_calls in advance since in
# any case they will be needed to convert the db schema
# to OpenAI messages
for msg in messages:
if msg.entity == Entity.AI_TOOL:
# This awaits the lazy loading, ensuring tool_calls is populated now.
await msg.awaitable_attrs.tool_calls
# No need to await since it has been awaited in tool filtering dependency
messages: list[Messages] = thread.messages

if (
not messages
Expand All @@ -326,6 +315,9 @@ async def stream_chat_agent(
content=json.dumps({"role": "user", "content": user_request.content}),
is_complete=True,
model=None,
selected_tools=[
ToolSelection(selected_tools=tool.name) for tool in agent.tools
],
)
)

Expand Down
5 changes: 2 additions & 3 deletions backend/src/neuroagent/rules/ai-message-formatting.mdc
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ description: AI-message formatting guidelines
---
# Message formatting guidelines

## Structure and Organization
- Use hierarchical headers (##, ###) to organize information by categories and subcategories
# AI Message Formatting Guidelines
- Always use rich markdown formatting.

## Content Structure Guidelines

- Use hierarchical headers (##, ###) to organize information by categories and subcategories
- Create detailed subsections for each specific area or entity
- Group related items under appropriate subheaders
- Use descriptive section titles that specify the exact areas or categories
Expand Down
Loading