Skip to content

Commit d86e227

Browse files
committed
feat(tools): rclpy based ros2 publisher and single msg reader
1 parent d570a3a commit d86e227

File tree

3 files changed

+175
-0
lines changed

3 files changed

+175
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import os
2+
from typing import List
3+
4+
import rclpy
5+
from langchain_core.messages import HumanMessage, SystemMessage
6+
from langchain_openai import ChatOpenAI
7+
8+
from rai.scenario_engine.messages import AgentLoop
9+
from rai.scenario_engine.scenario_engine import ScenarioPartType, ScenarioRunner
10+
from rai.tools.ros.cat_demo_tools import FinishTool
11+
from rai.tools.ros.rclpy import (
12+
Ros2GetOneMsgFromTopicTool,
13+
Ros2PubMessageTool,
14+
get_topics_names_and_types,
15+
)
16+
17+
18+
def main():
19+
tools = [
20+
get_topics_names_and_types,
21+
Ros2PubMessageTool(),
22+
Ros2GetOneMsgFromTopicTool(),
23+
FinishTool(),
24+
]
25+
26+
scenario: List[ScenarioPartType] = [
27+
SystemMessage(
28+
content="You are an autonomous agent. Your main goal is to fulfill the user's requests. "
29+
"Do not make assumptions about the environment you are currently in. "
30+
"Use the tooling provided to gather information about the environment."
31+
),
32+
HumanMessage(content="The robot is moving. Send robot to the random location"),
33+
AgentLoop(stop_action=FinishTool().__class__.__name__, stop_iters=50),
34+
]
35+
36+
log_usage = all((os.getenv("LANGFUSE_PK"), os.getenv("LANGFUSE_SK")))
37+
llm = ChatOpenAI(model="gpt-4o")
38+
39+
rclpy.init()
40+
runner = ScenarioRunner(
41+
scenario,
42+
llm,
43+
tools=tools,
44+
llm_type="openai",
45+
scenario_name="Husarion example",
46+
log_usage=log_usage,
47+
)
48+
runner.run()
49+
50+
rclpy.shutdown()
51+
52+
53+
if __name__ == "__main__":
54+
main()

src/rai/tools/ros/rclpy.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from typing import Any, Dict, Tuple, Type
2+
3+
import rclpy
4+
from langchain.tools import BaseTool, tool
5+
from langchain_core.pydantic_v1 import BaseModel, Field
6+
from rclpy.node import Node
7+
from rosidl_parser.definition import NamespacedType
8+
from rosidl_runtime_py.import_message import import_message_from_namespaced_type
9+
from rosidl_runtime_py.set_message import set_message_fields
10+
from rosidl_runtime_py.utilities import get_namespaced_type
11+
12+
from rai.communication.ros_communication import SingleMessageGrabber
13+
14+
from .utils import import_message_from_str
15+
16+
17+
@tool
18+
def get_topics_names_and_types():
19+
"""Call rclpy.node.Node().get_topics_names_and_types(). Return in a csv format. topic_name, serice_type"""
20+
21+
node = Node(node_name="rai_tool_node")
22+
rclpy.spin_once(node, timeout_sec=2)
23+
try:
24+
return [
25+
(topic_name, topic_type)
26+
for topic_name, topic_type in node.get_topic_names_and_types()
27+
if len(topic_name.split("/")) <= 2
28+
]
29+
finally:
30+
node.destroy_node()
31+
32+
33+
class Ros2GetOneMsgFromTopicInput(BaseModel):
34+
"""Input for the get_current_position tool."""
35+
36+
topic: str = Field(..., description="Ros2 topic")
37+
msg_type: str = Field(
38+
..., description="Type of ros2 message in typical ros2 format."
39+
)
40+
41+
42+
class Ros2GetOneMsgFromTopicTool(BaseTool):
43+
"""Get one message from a specific ros2 topic"""
44+
45+
name = "Ros2GetOneMsgFromTopic"
46+
description: str = "A tool for getting one message from a ros2 topic"
47+
48+
args_schema: Type[Ros2GetOneMsgFromTopicInput] = Ros2GetOneMsgFromTopicInput
49+
50+
def _run(self, topic: str, msg_type: str):
51+
"""Gets the current position from the specified topic."""
52+
msg_cls: Type = import_message_from_str(msg_type)
53+
54+
grabber = SingleMessageGrabber(topic, msg_cls, timeout_sec=10)
55+
msg = grabber.get_data()
56+
57+
if msg is None:
58+
return {"content": "Failed to get the position, wrong topic?"}
59+
60+
return {
61+
"content": str(msg),
62+
}
63+
64+
65+
class PubRos2MessageToolInput(BaseModel):
66+
"""Input for the set_goal_pose tool."""
67+
68+
topic_name: str = Field(..., description="Ros2 topic to publish the goal pose to")
69+
msg_type: str = Field(
70+
..., description="Type of ros2 message in typical ros2 format."
71+
)
72+
msg_args: Dict[str, Any] = Field(
73+
..., description="The arguments of the service call."
74+
)
75+
76+
77+
class Ros2PubMessageTool(BaseTool):
78+
"""Set the goal pose for the robot"""
79+
80+
name = "PubRos2MessageTool"
81+
description: str = "A tool for setting the goal pose for the robot."
82+
83+
args_schema: Type[PubRos2MessageToolInput] = PubRos2MessageToolInput
84+
85+
def _build_msg(
86+
self, msg_type: str, msg_args: Dict[str, Any]
87+
) -> Tuple[object, object]:
88+
msg_namespaced_type: NamespacedType = get_namespaced_type(msg_type)
89+
msg_cls = import_message_from_namespaced_type(msg_namespaced_type)
90+
msg = msg_cls()
91+
set_message_fields(msg, msg_args)
92+
return msg, msg_cls
93+
94+
def _run(self, topic_name: str, msg_type: str, msg_args: Dict[str, Any]):
95+
"""Sets the goal pose for the robot."""
96+
97+
msg, msg_cls = self._build_msg(msg_type, msg_args)
98+
99+
node = Node(node_name="RAI_PubRos2MessageTool")
100+
publisher = node.create_publisher(
101+
msg_cls, topic_name, 10
102+
) # TODO(boczekbartek): infer qos profile from topic info
103+
try:
104+
msg.header.stamp = node.get_clock().now().to_msg()
105+
msg.header.frame_id = "map"
106+
107+
rclpy.spin_once(node)
108+
publisher.publish(msg)
109+
finally:
110+
node.destroy_publisher(publisher)
111+
node.destroy_node()

src/rai/tools/ros/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from typing import Type
2+
3+
from rosidl_parser.definition import NamespacedType
4+
from rosidl_runtime_py.import_message import import_message_from_namespaced_type
5+
from rosidl_runtime_py.utilities import get_namespaced_type
6+
7+
8+
def import_message_from_str(msg_type: str) -> Type:
9+
msg_namespaced_type: NamespacedType = get_namespaced_type(msg_type)
10+
return import_message_from_namespaced_type(msg_namespaced_type)

0 commit comments

Comments
 (0)