Skip to content

Tool retrieval 2 #417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Python interpreter tool.
- Auto alembic upgrade.
- New utterances in false positive semantic route.
- Tool pre-selection.
- Ephys tool.

## [v0.6.4] - 02.07.2025
Expand Down
1 change: 1 addition & 0 deletions backend/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ NEUROAGENT_TOOLS__OBI_ONE__URL=
NEUROAGENT_TOOLS__ENTITYCORE__URL=
NEUROAGENT_TOOLS__THUMBNAIL_GENERATION__URL=
NEUROAGENT_TOOLS__WHITELISTED_TOOL_REGEX=
NEUROAGENT_TOOLS__MIN_TOOL_SELECTION=


NEUROAGENT_LLM__SUGGESTION_MODEL=
Expand Down
2 changes: 1 addition & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ preview = true

[tool.ruff]
line-length = 88
target-version = "py310"
target-version = "py311"
exclude = ["src/neuroagent/tools/autogenerated_types"]

[tool.ruff.lint]
Expand Down
1 change: 1 addition & 0 deletions backend/src/neuroagent/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class SettingsTools(BaseModel):
entitycore: SettingsEntityCore = SettingsEntityCore()
web_search: SettingsWebSearch = SettingsWebSearch()
thumbnail_generation: SettingsThumbnailGeneration = SettingsThumbnailGeneration()
min_tool_selection: int = 10
whitelisted_tool_regex: str | None = None

model_config = ConfigDict(frozen=True)
Expand Down
126 changes: 97 additions & 29 deletions backend/src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from datetime import datetime, timezone
from functools import cache
from pathlib import Path
from typing import Annotated, Any, AsyncIterator
from typing import Annotated, Any, AsyncIterator, Literal

import boto3
from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPBearer
from httpx import AsyncClient, HTTPStatusError, get
from obp_accounting_sdk import AsyncAccountingSessionFactory
from openai import AsyncOpenAI
from pydantic import BaseModel, Field
from redis import asyncio as aioredis
from semantic_router.routers import SemanticRouter
from sqlalchemy import select
Expand Down Expand Up @@ -102,6 +103,7 @@
WebSearchTool,
)
from neuroagent.tools.base_tool import BaseTool
from neuroagent.utils import messages_to_openai_content

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -244,6 +246,33 @@ async def get_user_info(
raise HTTPException(status_code=404, detail="User info url not provided.")


async def get_thread(
    user_info: Annotated[UserInfo, Depends(get_user_info)],
    thread_id: str,
    session: Annotated[AsyncSession, Depends(get_session)],
) -> Threads:
    """Check if the current thread / user matches.

    Fetches the thread owned by the authenticated user, raises a 404 when no
    such thread exists, and validates that the user's groups grant access to
    the thread's virtual lab / project before returning it.
    """
    # A thread is only visible to its owner: filter on both ids at once.
    query = select(Threads).where(
        Threads.user_id == user_info.sub, Threads.thread_id == thread_id
    )
    result = await session.execute(query)
    thread = result.scalars().one_or_none()

    if thread is None:
        raise HTTPException(
            status_code=404,
            detail={"detail": "Thread not found."},
        )

    # Raises if the user's groups do not cover the thread's vlab/project.
    validate_project(
        groups=user_info.groups,
        virtual_lab_id=thread.vlab_id,
        project_id=thread.project_id,
    )
    return thread


def get_mcp_client(request: Request) -> MCPClient | None:
"""Get the MCP client from the app state."""
if request.app.state.mcp_client is None:
Expand Down Expand Up @@ -415,6 +444,72 @@ async def get_selected_tools(
return selected_tools


async def filtered_tools(
request: Request,
thread: Annotated[Threads, Depends(get_thread)],
tool_list: Annotated[list[type[BaseTool]], Depends(get_selected_tools)],
openai_client: Annotated[AsyncOpenAI, Depends(get_openai_client)],
settings: Annotated[Settings, Depends(get_settings)],
) -> list[type[BaseTool]]:
"""Based on the current conversation, select relevant tools."""
if request.method == "GET":
return tool_list

# Awaiting here makes downstream calls already loaded so no performance issue
messages = await thread.awaitable_attrs.messages
openai_messages = await messages_to_openai_content(messages)

# Remove the content of tool responses to save tokens
for message in openai_messages:
if message["role"] == "tool":
message["content"] = "..."

body = await request.json()
openai_messages.append({"role": "user", "content": body["content"]})

system_prompt = f"""TASK: Filter tools for AI agent based on conversation relevance.

AVAILABLE TOOLS:
{chr(10).join(f"{tool.name}: {tool.description}" for tool in tool_list)}

INSTRUCTIONS:
1. Analyze the conversation to identify required capabilities
2. Select at least 5 of the most relevant tools by name only
3. BIAS TOWARD INCLUSION: If uncertain about a tool's relevance, include it - better to provide too many tools than too few
4. Only exclude tools that are clearly irrelevant to the conversation
5. Output format: comma-separated list of tool names
6. Do not respond to user queries - only filter tools

OUTPUT: [tool_name1, tool_name2, ...]"""

tool_names = [tool.name for tool in tool_list]
TOOL_NAMES_LITERAL = Literal[*tool_names] if len(tool_names) > 0 else str

class ToolSelection(BaseModel):
"""Data class for tool selection by an LLM."""

selected_tools: list[TOOL_NAMES_LITERAL] = Field( # type: ignore
min_length=min(len(tool_names), settings.tools.min_tool_selection),
description=f"List of selected tool names, minimum {min(len(tool_names), settings.tools.min_tool_selection)} items. Must contain all of the tools relevant to the conversation.",
)

# Rest of your code remains the same
response = await openai_client.beta.chat.completions.parse(
messages=[{"role": "system", "content": system_prompt}, *openai_messages], # type: ignore
model="gpt-4o-mini",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be turned into an env var but we have enough already. Please let me know if you would prefer an env var.

response_format=ToolSelection,
)

if response.choices[0].message.parsed:
selected_tools = response.choices[0].message.parsed.selected_tools
logger.info(
f"QUERY: {body['content']}, #TOOLS: {len(selected_tools)}, SELECTED TOOLS: {selected_tools}"
)
return [tool for tool in tool_list if tool.name in selected_tools]
else:
return []


@cache
def get_rules_dir() -> Path:
"""Get the path to the rules directory."""
Expand Down Expand Up @@ -476,7 +571,7 @@ def get_system_prompt(rules_dir: Annotated[Path, Depends(get_rules_dir)]) -> str


def get_starting_agent(
tool_list: Annotated[list[type[BaseTool]], Depends(get_selected_tools)],
tool_list: Annotated[list[type[BaseTool]], Depends(filtered_tools)],
system_prompt: Annotated[str, Depends(get_system_prompt)],
) -> Agent:
"""Get the starting agent."""
Expand All @@ -488,33 +583,6 @@ def get_starting_agent(
return agent


async def get_thread(
    user_info: Annotated[UserInfo, Depends(get_user_info)],
    thread_id: str,
    session: Annotated[AsyncSession, Depends(get_session)],
) -> Threads:
    """Check if the current thread / user matches.

    Looks up the thread by (user_id, thread_id), raises a 404 HTTPException
    when it does not exist for this user, and validates project access via
    the user's groups before returning the thread row.
    """
    # Filter on both the owner and the thread id so users cannot read
    # each other's threads.
    thread_result = await session.execute(
        select(Threads).where(
            Threads.user_id == user_info.sub, Threads.thread_id == thread_id
        )
    )
    thread = thread_result.scalars().one_or_none()
    if not thread:
        raise HTTPException(
            status_code=404,
            detail={
                "detail": "Thread not found.",
            },
        )
    # Raises if the user's groups do not grant access to the thread's
    # virtual lab / project.
    validate_project(
        groups=user_info.groups,
        virtual_lab_id=thread.vlab_id,
        project_id=thread.project_id,
    )
    return thread


def get_semantic_routes(request: Request) -> SemanticRouter | None:
"""Get the semantic route object for basic guardrails."""
return request.app.state.semantic_router
Expand Down
11 changes: 2 additions & 9 deletions backend/src/neuroagent/app/routers/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,8 @@ async def stream_chat_agent(
agent.model = agent.model.removeprefix("openai/")
agents_routine.client = openai_client

messages: list[Messages] = await thread.awaitable_attrs.messages
# Since the session is not reinstantiated in stream.py
# we need to lazy load the tool_calls in advance since in
# any case they will be needed to convert the db schema
# to OpenAI messages
for msg in messages:
if msg.entity == Entity.AI_TOOL:
# This awaits the lazy loading, ensuring tool_calls is populated now.
await msg.awaitable_attrs.tool_calls
# No need to await since it has been awaited in tool filtering dependency
messages: list[Messages] = thread.messages

if (
not messages
Expand Down
5 changes: 2 additions & 3 deletions backend/src/neuroagent/rules/ai-message-formatting.mdc
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ description: AI-message formatting guidelines
---
# Message formatting guidelines

## Structure and Organization
- Use hierarchical headers (##, ###) to organize information by categories and subcategories
# AI Message Formatting Guidelines
- Always use rich markdown formatting.

## Content Structure Guidelines

- Use hierarchical headers (##, ###) to organize information by categories and subcategories
- Create detailed subsections for each specific area or entity
- Group related items under appropriate subheaders
- Use descriptive section titles that specify the exact areas or categories
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: https://staging.openbraininstitute.org/api/entitycore/openapi.json
# timestamp: 2025-07-09T08:29:12+00:00
# timestamp: 2025-07-09T08:41:55+00:00

from __future__ import annotations

Expand Down
26 changes: 1 addition & 25 deletions backend/src/neuroagent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,31 +44,7 @@ async def messages_to_openai_content(
messages = []
if db_messages:
for msg in db_messages:
if msg.content and msg.entity == Entity.AI_TOOL:
# Load the base content
content = json.loads(msg.content)

# Get the associated tool calls
tool_calls = msg.tool_calls

# Format it back into the json OpenAI expects
tool_calls_content = [
{
"function": {
"arguments": tool_call.arguments,
"name": tool_call.name,
},
"id": tool_call.tool_call_id,
"type": "function",
}
for tool_call in tool_calls
]

# Assign it back to the main content
content["tool_calls"] = tool_calls_content
messages.append(content)
else:
messages.append(json.loads(msg.content))
messages.append(json.loads(msg.content))

return messages

Expand Down
18 changes: 16 additions & 2 deletions backend/tests/app/routers/test_qa.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from typing import Annotated
from unittest.mock import Mock

import pytest
from fastapi import Depends

from neuroagent.app.config import Settings
from neuroagent.app.database.sql_schemas import Threads
from neuroagent.app.dependencies import (
filtered_tools,
get_agents_routine,
get_openai_client,
get_settings,
get_thread,
)
from neuroagent.app.main import app
from neuroagent.app.routers import qa
Expand Down Expand Up @@ -178,8 +183,17 @@ def test_chat_streamed(app_client, httpx_mock, patch_required_env, db_connection
rate_limiter={"disabled": True},
)
app.dependency_overrides[get_settings] = lambda: test_settings
agent_routine = Mock()
app.dependency_overrides[get_agents_routine] = lambda: agent_routine
app.dependency_overrides[get_agents_routine] = lambda: Mock()

async def mock_dependency_with_lazy_loading(
thread: Annotated[Threads, Depends(get_thread)],
):
# Thread is injected automatically
# Perform the lazy loading
await thread.awaitable_attrs.messages
return []

app.dependency_overrides[filtered_tools] = mock_dependency_with_lazy_loading

expected_tokens = (
b"Calling tool : resolve_entities_tool with arguments : {brain_region:"
Expand Down