
Commit 69035b3

refactor: moved dual agent to rai_bench
1 parent 1978020 commit 69035b3

File tree

5 files changed: +125 -98 lines changed


src/rai_bench/rai_bench/agents.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from functools import partial
from typing import List, Optional

from langchain.chat_models.base import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
)
from langchain_core.tools import BaseTool
from langgraph.graph import START, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import tools_condition
from rai.agents.langchain.core.conversational_agent import State, agent
from rai.agents.langchain.core.tool_runner import ToolRunner


def multimodal_to_tool_bridge(state: State):
    """LangChain workflow node that bridges nodes backed by different LLMs.

    Removes images from the messages so only text is carried over as context.
    """

    cleaned_messages: List[BaseMessage] = []
    for msg in state["messages"]:
        if isinstance(msg, HumanMessage):
            # Remove images but keep the direct request
            if isinstance(msg.content, list):
                # Extract text parts only
                text_parts = [
                    part.get("text", "")
                    for part in msg.content
                    if isinstance(part, dict) and part.get("type") == "text"
                ]
                if text_parts:
                    cleaned_messages.append(HumanMessage(content=" ".join(text_parts)))
            else:
                cleaned_messages.append(msg)
        elif isinstance(msg, AIMessage):
            # Keep AI messages for context
            cleaned_messages.append(msg)

    state["messages"] = cleaned_messages
    return state


def create_multimodal_to_tool_agent(
    multimodal_llm: BaseChatModel,
    tool_llm: BaseChatModel,
    tools: List[BaseTool],
    multimodal_system_prompt: str,
    tool_system_prompt: str,
    logger: Optional[logging.Logger] = None,
    debug: bool = False,
) -> CompiledStateGraph:
    """
    Creates an agent flow where inputs first go to a multimodal LLM,
    then its output is passed to a tool-calling LLM.
    Can be useful when the multimodal LLM does not support tool calling.

    Args:
        tools: List of tools available to the tool agent

    Returns:
        Compiled state graph
    """
    if logger:
        _logger = logger
    else:
        _logger = logging.getLogger(__name__)

    _logger.info("Creating multimodal to tool agent flow")

    tool_llm_with_tools = tool_llm.bind_tools(tools)
    tool_node = ToolRunner(tools=tools, logger=_logger)

    workflow = StateGraph(State)
    workflow.add_node(
        "thinker",
        partial(agent, multimodal_llm, _logger, multimodal_system_prompt),
    )
    # Context bridge that strips images before handing off to the tool agent
    workflow.add_node(
        "context_bridge",
        multimodal_to_tool_bridge,
    )
    workflow.add_node(
        "tool_agent",
        partial(agent, tool_llm_with_tools, _logger, tool_system_prompt),
    )
    workflow.add_node("tools", tool_node)

    workflow.add_edge(START, "thinker")
    workflow.add_edge("thinker", "context_bridge")
    workflow.add_edge("context_bridge", "tool_agent")

    workflow.add_conditional_edges(
        "tool_agent",
        tools_condition,
    )

    # Tool node loops back to the tool agent
    workflow.add_edge("tools", "tool_agent")

    app = workflow.compile(debug=debug)
    _logger.info("Multimodal to tool agent flow created")
    return app
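
For orientation, a minimal usage sketch of the relocated factory follows. The chat models, the stub tool, and the prompts are illustrative placeholders only; they are not part of this commit, and any vision-capable model plus any tool-calling model should fit the same pattern.

# Hypothetical usage sketch (models, tool, and prompts are placeholders, not from this commit).
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI

from rai_bench.agents import create_multimodal_to_tool_agent


@tool
def get_object_positions() -> str:
    """Return object positions in the scene (stub for illustration)."""
    return "red_cube: (0.10, 0.25, 0.00)"


app = create_multimodal_to_tool_agent(
    multimodal_llm=ChatOpenAI(model="gpt-4o"),      # example vision-capable model
    tool_llm=ChatOpenAI(model="gpt-4o-mini"),       # example tool-calling model
    tools=[get_object_positions],
    multimodal_system_prompt="Describe the scene and plan the next manipulation step.",
    tool_system_prompt="Execute the plan by calling the available tools.",
)

# The compiled graph is invoked like any LangGraph app: the thinker sees the
# (possibly multimodal) input, the bridge drops images, the tool agent acts.
result = app.invoke({"messages": [HumanMessage(content="Pick up the red cube.")]})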

src/rai_bench/rai_bench/manipulation_o3de/benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,6 @@
 from launch_ros.substitutions import FindPackageShare
 from rai.agents.langchain.core import (
     create_conversational_agent,
-    create_multimodal_to_tool_agent,
 )
 from rai.communication.ros2.connectors import ROS2Connector
 from rai.messages import HumanMultimodalMessage
@@ -46,6 +45,7 @@
 )
 from rai_open_set_vision.tools import GetGrabbingPointTool

+from rai_bench.agents import create_multimodal_to_tool_agent
 from rai_bench.base_benchmark import BaseBenchmark, RunSummary, TimeoutException
 from rai_bench.manipulation_o3de.interfaces import Task
 from rai_bench.manipulation_o3de.results_tracking import (

src/rai_bench/rai_bench/tool_calling_agent/benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -25,10 +25,10 @@
 from langgraph.graph.state import CompiledStateGraph
 from rai.agents.langchain.core import (
     create_conversational_agent,
-    create_multimodal_to_tool_agent,
 )
 from rai.messages import HumanMultimodalMessage

+from rai_bench.agents import create_multimodal_to_tool_agent
 from rai_bench.base_benchmark import BaseBenchmark, TimeoutException
 from rai_bench.results_processing.langfuse_scores_tracing import ScoreTracingHandler
 from rai_bench.tool_calling_agent.interfaces import (

src/rai_core/rai/agents/langchain/core/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -15,7 +15,6 @@
 from .conversational_agent import State as ConversationalAgentState
 from .conversational_agent import (
     create_conversational_agent,
-    create_multimodal_to_tool_agent,
 )
 from .react_agent import (
     ReActAgentState,
@@ -29,7 +28,6 @@
     "ReActAgentState",
     "ToolRunner",
     "create_conversational_agent",
-    "create_multimodal_to_tool_agent",
     "create_react_runnable",
     "create_state_based_runnable",
 ]
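
Because the symbol is no longer re-exported from rai_core, downstream code must import it from its new location, as the two benchmark modules above now do:

# Before this commit:
# from rai.agents.langchain.core import create_multimodal_to_tool_agent

# After this commit:
from rai_bench.agents import create_multimodal_to_tool_agent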

src/rai_core/rai/agents/langchain/core/conversational_agent.py

Lines changed: 0 additions & 94 deletions
@@ -19,9 +19,7 @@

 from langchain.chat_models.base import BaseChatModel
 from langchain_core.messages import (
-    AIMessage,
     BaseMessage,
-    HumanMessage,
     SystemMessage,
 )
 from langchain_core.tools import BaseTool
@@ -94,95 +92,3 @@ def create_conversational_agent(
     app = workflow.compile(debug=debug)
     _logger.info("State based agent created")
     return app
-
-
-def multimodal_to_tool_bridge(state: State):
-    """Node of langchain workflow designed to bridge
-    nodes with llms. Removing images for context
-    """
-
-    cleaned_messages: List[BaseMessage] = []
-    for msg in state["messages"]:
-        if isinstance(msg, HumanMessage):
-            # Remove images but keep the direct request
-            if isinstance(msg.content, list):
-                # Extract text only
-                text_parts = [
-                    part.get("text", "")
-                    for part in msg.content
-                    if isinstance(part, dict) and part.get("type") == "text"
-                ]
-                if text_parts:
-                    cleaned_messages.append(HumanMessage(content=" ".join(text_parts)))
-            else:
-                cleaned_messages.append(msg)
-        elif isinstance(msg, AIMessage):
-            # Keep AI messages for context
-            cleaned_messages.append(msg)
-
-    state["messages"] = cleaned_messages
-    return state
-
-
-def create_multimodal_to_tool_agent(
-    multimodal_llm: BaseChatModel,
-    tool_llm: BaseChatModel,
-    tools: List[BaseTool],
-    multimodal_system_prompt: str,
-    tool_system_prompt: str,
-    logger: Optional[logging.Logger] = None,
-    debug: bool = False,
-) -> CompiledStateGraph:
-    """
-    Creates an agent flow where inputs first go to a multimodal LLM,
-    then its output is passed to a tool-calling LLM.
-    Can be usefull when multimodal llm does not provide tool calling.
-
-    Args:
-        tools: List of tools available to the tool agent
-
-    Returns:
-        Compiled state graph
-    """
-    _logger = None
-    if logger:
-        _logger = logger
-    else:
-        _logger = logging.getLogger(__name__)
-
-    _logger.info("Creating multimodal to tool agent flow")
-
-    tool_llm_with_tools = tool_llm.bind_tools(tools)
-    tool_node = ToolRunner(tools=tools, logger=_logger)
-
-    workflow = StateGraph(State)
-    workflow.add_node(
-        "thinker",
-        partial(agent, multimodal_llm, _logger, multimodal_system_prompt),
-    )
-    # context bridge for altering the
-    workflow.add_node(
-        "context_bridge",
-        multimodal_to_tool_bridge,
-    )
-    workflow.add_node(
-        "tool_agent",
-        partial(agent, tool_llm_with_tools, _logger, tool_system_prompt),
-    )
-    workflow.add_node("tools", tool_node)
-
-    workflow.add_edge(START, "thinker")
-    workflow.add_edge("thinker", "context_bridge")
-    workflow.add_edge("context_bridge", "tool_agent")
-
-    workflow.add_conditional_edges(
-        "tool_agent",
-        tools_condition,
-    )
-
-    # Tool node goes back to tool agent
-    workflow.add_edge("tools", "tool_agent")
-
-    app = workflow.compile(debug=debug)
-    _logger.info("Multimodal to tool agent flow created")
-    return app
