Skip to content

Commit b4f714b

Browse files
boczekbartekmaciejmajek
authored andcommitted
feat: add simple rclpy based function calling (#11)
* feat(tools): rclpy based ros2 publisher and single msg reader * refactor * fix: remove misleading docstrings * add typehints * fix: rclpy ros tools docs * don't add frame_id to msg header by default * Update src/rai/tools/ros/utils.py Co-authored-by: Maciej Majek <[email protected]> * fix: ros2 pub tool * refactor: rename `rclpy` scripts to `native` - discussion: #11 (comment) * ros native tools refactor to use global node * feat(`tool_runner`): allow empty args_schema + logging chore: pre-commit * revert('ScenarioRunner`): convert to node * refactor: change logging to logger * logger.error -> logger.warning for tool call error * fix: husarion_poc_example_ros_native * fix logger * Update src/rai/scenario_engine/tool_runner.py Co-authored-by: Maciej Majek <[email protected]> --------- Co-authored-by: Maciej Majek <[email protected]>
1 parent 2a7574c commit b4f714b

File tree

6 files changed

+282
-11
lines changed

6 files changed

+282
-11
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import logging
2+
import os
3+
from typing import List
4+
5+
import rclpy
6+
from langchain_core.messages import HumanMessage, SystemMessage
7+
from langchain_openai import ChatOpenAI
8+
from rclpy.node import Node
9+
10+
from rai.scenario_engine.messages import AgentLoop
11+
from rai.scenario_engine.scenario_engine import ScenarioPartType, ScenarioRunner
12+
from rai.tools.ros.cat_demo_tools import FinishTool
13+
from rai.tools.ros.native import (
14+
Ros2GetOneMsgFromTopicTool,
15+
Ros2GetTopicsNamesAndTypesTool,
16+
Ros2PubMessageTool,
17+
)
18+
19+
20+
def main():
21+
scenario: List[ScenarioPartType] = [
22+
SystemMessage(
23+
content="You are an autonomous agent. Your main goal is to fulfill the user's requests. "
24+
"Do not make assumptions about the environment you are currently in. "
25+
"Use the tooling provided to gather information about the environment."
26+
),
27+
HumanMessage(content="The robot is moving. Send robot to the random location"),
28+
AgentLoop(stop_action=FinishTool().__class__.__name__, stop_iters=50),
29+
]
30+
31+
log_usage = all((os.getenv("LANGFUSE_PK"), os.getenv("LANGFUSE_SK")))
32+
llm = ChatOpenAI(model="gpt-4o")
33+
34+
rclpy.init()
35+
36+
rai_node = Node("rai") # type: ignore
37+
38+
tools = [
39+
Ros2GetTopicsNamesAndTypesTool(),
40+
Ros2GetOneMsgFromTopicTool(node=rai_node),
41+
Ros2PubMessageTool(node=rai_node),
42+
FinishTool(),
43+
]
44+
45+
runner = ScenarioRunner(
46+
scenario,
47+
llm,
48+
tools=tools,
49+
llm_type="openai",
50+
scenario_name="Husarion example",
51+
log_usage=log_usage,
52+
logging_level=logging.INFO,
53+
)
54+
55+
runner.run()
56+
rai_node.destroy_node()
57+
58+
rclpy.shutdown()
59+
60+
61+
if __name__ == "__main__":
62+
main()

poetry.lock

Lines changed: 47 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ piper-tts = "^1.2.0"
3131
tabulate = "^0.9.0"
3232
lark = "^1.1.9"
3333
langfuse = "^2.36.1"
34+
netifaces = "^0.11.0"
3435

3536
[tool.poetry.group.dev.dependencies]
3637
ipykernel = "^6.29.4"

src/rai/scenario_engine/tool_runner.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import Any, Dict, List, Literal, Sequence
23

34
from langchain.tools import BaseTool
@@ -22,21 +23,32 @@ def images_to_vendor_format(images: List[str], vendor: str) -> List[Dict[str, An
2223

2324

2425
def run_tool_call(
25-
tool_call: ToolCall, tools: Sequence[BaseTool]
26+
tool_call: ToolCall,
27+
tools: Sequence[BaseTool],
2628
) -> Dict[str, Any] | Any:
29+
logger = logging.getLogger(__name__)
2730
selected_tool = {k.name: k for k in tools}[tool_call["name"]]
31+
2832
try:
29-
args = selected_tool.args_schema(**tool_call["args"]) # type: ignore
33+
if selected_tool.args_schema is not None:
34+
args = selected_tool.args_schema(**tool_call["args"]).dict()
35+
else:
36+
args = dict()
3037
except Exception as e:
31-
return f"Error in preparing arguments for {selected_tool.name}: {e}"
38+
err_msg = f"Error in preparing arguments for {selected_tool.name}: {e}"
39+
logger.error(err_msg)
40+
return err_msg
3241

33-
print(f"Running tool: {selected_tool.name} with args: {args.dict()}")
42+
logger.info(f"Running tool: {selected_tool.name} with args: {args}")
3443

3544
try:
36-
tool_output = selected_tool.run(args.dict())
45+
tool_output = selected_tool.run(args)
3746
except Exception as e:
38-
return f"Error running tool {selected_tool.name}: {e}"
47+
err_msg = f"Error in running tool {selected_tool.name}: {e}"
48+
logger.warning(err_msg)
49+
return err_msg
3950

51+
logger.info(f"Successfully ran tool: {selected_tool.name}. Output: {tool_output}")
4052
return tool_output
4153

4254

src/rai/tools/ros/native.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
from typing import Any, Dict, Tuple, Type
2+
3+
import rclpy
4+
from langchain.tools import BaseTool
5+
from langchain_core.pydantic_v1 import BaseModel, Field
6+
from rclpy.impl.rcutils_logger import RcutilsLogger
7+
from rclpy.node import Node
8+
from ros2cli.node.strategy import NodeStrategy
9+
from rosidl_runtime_py.set_message import set_message_fields
10+
11+
from rai.communication.ros_communication import wait_for_message
12+
13+
from .utils import import_message_from_str
14+
15+
16+
class Ros2BaseInput(BaseModel):
17+
"""Empty input for ros2 tool"""
18+
19+
20+
class Ros2BaseTool(BaseTool):
21+
node: Node = Field(..., exclude=True, include=False, required=True)
22+
23+
args_schema: Type[Ros2BaseInput] = Ros2BaseInput
24+
25+
@property
26+
def logger(self) -> RcutilsLogger:
27+
return self.node.get_logger()
28+
29+
30+
class Ros2GetTopicsNamesAndTypesTool(BaseTool):
31+
name: str = "Ros2GetTopicsNamesAndTypes"
32+
description: str = "A tool for getting all ros2 topics names and types"
33+
34+
def _run(self):
35+
with NodeStrategy(dict()) as node:
36+
return [
37+
(topic_name, topic_type)
38+
for topic_name, topic_type in node.get_topic_names_and_types()
39+
if len(topic_name.split("/")) <= 2
40+
]
41+
42+
43+
class Ros2GetOneMsgFromTopicInput(BaseModel):
44+
"""Input for the get_current_position tool."""
45+
46+
topic: str = Field(..., description="Ros2 topic")
47+
msg_type: str = Field(
48+
..., description="Type of ros2 message in typical ros2 format."
49+
)
50+
timeout_sec: int = Field(
51+
10, description="The time in seconds to wait for a message to be received."
52+
)
53+
54+
55+
class Ros2GetOneMsgFromTopicTool(Ros2BaseTool):
56+
"""Get one message from a specific ros2 topic"""
57+
58+
name: str = "Ros2GetOneMsgFromTopic"
59+
description: str = "A tool for getting one message from a ros2 topic"
60+
61+
args_schema: Type[Ros2GetOneMsgFromTopicInput] = Ros2GetOneMsgFromTopicInput
62+
63+
def _run(self, topic: str, msg_type: str, timeout_sec: int):
64+
"""Gets the current position from the specified topic."""
65+
msg_cls: Type = import_message_from_str(msg_type)
66+
67+
qos_profile = (
68+
rclpy.qos.qos_profile_sensor_data
69+
) # TODO(@boczekbartek): infer QoS from topic
70+
71+
success, msg = wait_for_message(
72+
msg_cls,
73+
self.node,
74+
topic,
75+
qos_profile=qos_profile,
76+
time_to_wait=timeout_sec,
77+
)
78+
msg = msg.get_data()
79+
80+
if success:
81+
self.logger.info(f"Received message of type {msg_type} from topic {topic}")
82+
else:
83+
self.logger.error(
84+
f"Failed to receive message of type {msg_type} from topic {topic}"
85+
)
86+
87+
if msg is None:
88+
return {"content": "No message received."}
89+
90+
return {
91+
"content": str(msg),
92+
}
93+
94+
95+
class PubRos2MessageToolInput(BaseModel):
96+
topic_name: str = Field(..., description="Ros2 topic to publish the message")
97+
msg_type: str = Field(
98+
..., description="Type of ros2 message in typical ros2 format."
99+
)
100+
msg_args: Dict[str, Any] = Field(
101+
..., description="The arguments of the service call."
102+
)
103+
104+
105+
class Ros2PubMessageTool(Ros2BaseTool):
106+
name: str = "PubRos2MessageTool"
107+
description: str = """A tool for publishing a message to a ros2 topic
108+
Example usage:
109+
110+
```python
111+
tool = Ros2PubMessageTool()
112+
tool.run(
113+
{
114+
"topic_name": "/some_topic",
115+
"msg_type": "geometry_msgs/Point",
116+
"msg_args": {"x": 0.0, "y": 0.0, "z": 0.0},
117+
}
118+
)
119+
120+
```
121+
"""
122+
123+
args_schema: Type[PubRos2MessageToolInput] = PubRos2MessageToolInput
124+
125+
def _build_msg(
126+
self, msg_type: str, msg_args: Dict[str, Any]
127+
) -> Tuple[object, Type]:
128+
msg_cls: Type = import_message_from_str(msg_type)
129+
msg = msg_cls()
130+
set_message_fields(msg, msg_args)
131+
return msg, msg_cls
132+
133+
def _run(self, topic_name: str, msg_type: str, msg_args: Dict[str, Any]):
134+
"""Publishes a message to the specified topic."""
135+
if "/msg/" not in msg_type:
136+
raise ValueError("msg_name must contain 'msg' in the name.")
137+
msg, msg_cls = self._build_msg(msg_type, msg_args)
138+
139+
publisher = self.node.create_publisher(
140+
msg_cls, topic_name, 10
141+
) # TODO(boczekbartek): infer qos profile from topic info
142+
143+
msg.header.stamp = self.node.get_clock().now().to_msg()
144+
publisher.publish(msg)

src/rai/tools/ros/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from typing import Type
2+
3+
from rosidl_parser.definition import NamespacedType
4+
from rosidl_runtime_py.import_message import import_message_from_namespaced_type
5+
from rosidl_runtime_py.utilities import get_namespaced_type
6+
7+
8+
def import_message_from_str(msg_type: str) -> Type[object]:
9+
msg_namespaced_type: NamespacedType = get_namespaced_type(msg_type)
10+
return import_message_from_namespaced_type(msg_namespaced_type)

0 commit comments

Comments
 (0)