Skip to content

Commit de67ecb

Browse files
committed
feat(tool_runner): allow empty args_schema + logging
chore: pre-commit
1 parent ac34283 commit de67ecb

File tree

4 files changed

+40
-26
lines changed

4 files changed

+40
-26
lines changed

examples/husarion_poc_example_ros_native.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,17 @@
44
import rclpy
55
from langchain_core.messages import HumanMessage, SystemMessage
66
from langchain_openai import ChatOpenAI
7+
from rclpy.node import Node
78

89
from rai.scenario_engine.messages import AgentLoop
910
from rai.scenario_engine.scenario_engine import ScenarioPartType, ScenarioRunner
1011
from rai.tools.ros.cat_demo_tools import FinishTool
1112
from rai.tools.ros.native import (
1213
Ros2GetOneMsgFromTopicTool,
14+
Ros2GetTopicsNamesAndTypesTool,
1315
Ros2PubMessageTool,
14-
Ros2GetTopicsNamesAndTypesTool
1516
)
1617

17-
from rclpy.node import Node
18-
from ros2cli.node.strategy import NodeStrategy
19-
2018

2119
def main():
2220
scenario: List[ScenarioPartType] = [
@@ -34,25 +32,25 @@ def main():
3432

3533
rclpy.init()
3634

37-
rai_node = Node("rai") # type: ignore
35+
rai_node = Node("rai") # type: ignore
3836

37+
tools = [
38+
Ros2GetTopicsNamesAndTypesTool(),
39+
Ros2GetOneMsgFromTopicTool(node=rai_node),
40+
Ros2PubMessageTool(node=rai_node),
41+
FinishTool(),
42+
]
3943

40-
runner = ScenarioRunner(
44+
runner = ScenarioRunner(
4145
scenario,
4246
llm,
47+
tools=tools,
4348
llm_type="openai",
4449
scenario_name="Husarion example",
4550
log_usage=log_usage,
46-
logging_level="DEBUG"
51+
logging_level="DEBUG",
4752
)
4853

49-
tools = [
50-
Ros2GetTopicsNamesAndTypesTool(),
51-
Ros2GetOneMsgFromTopicTool(node=rai_node),
52-
Ros2PubMessageTool(node=rai_node),
53-
FinishTool(),
54-
]
55-
5654
runner.bind_tools(tools)
5755

5856
runner.run()

src/rai/scenario_engine/scenario_engine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import os
44
import pickle
55
from typing import Callable, Dict, List, Literal, Sequence, Union, cast
6-
from rclpy.node import Node
76

87
import coloredlogs
98
from langchain_core.language_models.chat_models import BaseChatModel
@@ -17,6 +16,7 @@
1716
from langchain_core.runnables import RunnableConfig
1817
from langchain_core.tools import BaseTool
1918
from langfuse.callback import CallbackHandler
19+
from rclpy.node import Node
2020

2121
from rai.history_saver import HistorySaver
2222
from rai.scenario_engine.messages import AgentLoop, FutureAiMessage
@@ -81,10 +81,10 @@ def __init__(
8181
log_usage: bool = True,
8282
use_cache: bool = False,
8383
):
84-
super().__init__(node_name="rai") # type: ignore
84+
super().__init__(node_name="rai") # type: ignore
8585
self.scenario = scenario
8686
self.tools = tools
87-
for t in self.tools: # TODO(@boczekbartek): refactor to the method
87+
for t in self.tools: # TODO(@boczekbartek): refactor to the method
8888
if isinstance(t, BaseRos2NativeTool):
8989
t.set_node(self)
9090
self.log_usage = log_usage

src/rai/scenario_engine/tool_runner.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import Any, Dict, List, Literal, Sequence
23

34
from langchain.tools import BaseTool
@@ -22,21 +23,31 @@ def images_to_vendor_format(images: List[str], vendor: str) -> List[Dict[str, An
2223

2324

2425
def run_tool_call(
25-
tool_call: ToolCall, tools: Sequence[BaseTool]
26+
tool_call: ToolCall,
27+
tools: Sequence[BaseTool],
2628
) -> Dict[str, Any] | Any:
2729
selected_tool = {k.name: k for k in tools}[tool_call["name"]]
30+
2831
try:
29-
args = selected_tool.args_schema(**tool_call["args"]) # type: ignore
32+
if selected_tool.args_schema is not None:
33+
args = selected_tool.args_schema(**tool_call["args"]).dict()
34+
else:
35+
args = dict()
3036
except Exception as e:
31-
return f"Error in preparing arguments for {selected_tool.name}: {e}"
37+
err_msg = f"Error in preparing arguments for {selected_tool.name}: {e}"
38+
logging.error(err_msg)
39+
return err_msg
3240

33-
print(f"Running tool: {selected_tool.name} with args: {args.dict()}")
41+
logging.info(f"Running tool: {selected_tool.name} with args: {args}")
3442

3543
try:
36-
tool_output = selected_tool.run(args.dict())
44+
tool_output = selected_tool.run(args)
3745
except Exception as e:
38-
return f"Error running tool {selected_tool.name}: {e}"
46+
err_msg = f"Error in running tool {selected_tool.name}: {e}"
47+
logging.error(err_msg)
48+
return err_msg
3949

50+
logging.info(f"Successfully ran tool: {selected_tool.name}. Output: {tool_output}")
4051
return tool_output
4152

4253

src/rai/tools/ros/native.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212

1313
from .utils import import_message_from_str
1414

15+
1516
class Ros2BaseInput(BaseModel):
16-
""" Empty input for ros2 tool """
17+
"""Empty input for ros2 tool"""
18+
1719

1820
class Ros2BaseTool(BaseTool):
1921
node: Node = Field(..., exclude=True, include=False, required=True)
@@ -24,10 +26,11 @@ class Ros2BaseTool(BaseTool):
2426
def logger(self) -> RcutilsLogger:
2527
return self.node.get_logger()
2628

29+
2730
class Ros2GetTopicsNamesAndTypesTool(BaseTool):
2831
name: str = "Ros2GetTopicsNamesAndTypes"
2932
description: str = "A tool for getting all ros2 topics names and types"
30-
33+
3134
def _run(self):
3235
with NodeStrategy(dict()) as node:
3336
return [
@@ -61,7 +64,9 @@ def _run(self, topic: str, msg_type: str, timeout_sec: int):
6164
"""Gets the current position from the specified topic."""
6265
msg_cls: Type = import_message_from_str(msg_type)
6366

64-
qos_profile = rclpy.qos.qos_profile_sensor_data # TODO(@boczekbartek): infer QoS from topic
67+
qos_profile = (
68+
rclpy.qos.qos_profile_sensor_data
69+
) # TODO(@boczekbartek): infer QoS from topic
6570
success, msg = wait_for_message(
6671
msg_cls,
6772
self.node,

0 commit comments

Comments
 (0)