Skip to content

Commit ff364b9

Browse files
committed
refactor
1 parent d86e227 commit ff364b9

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

examples/husarion_poc_example_rclpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
from rai.tools.ros.rclpy import (
1212
Ros2GetOneMsgFromTopicTool,
1313
Ros2PubMessageTool,
14-
get_topics_names_and_types,
14+
get_topics_names_and_types_tool,
1515
)
1616

1717

1818
def main():
1919
tools = [
20-
get_topics_names_and_types,
20+
get_topics_names_and_types_tool,
2121
Ros2PubMessageTool(),
2222
Ros2GetOneMsgFromTopicTool(),
2323
FinishTool(),

src/rai/tools/ros/rclpy.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,18 @@
44
from langchain.tools import BaseTool, tool
55
from langchain_core.pydantic_v1 import BaseModel, Field
66
from rclpy.node import Node
7-
from rosidl_parser.definition import NamespacedType
8-
from rosidl_runtime_py.import_message import import_message_from_namespaced_type
97
from rosidl_runtime_py.set_message import set_message_fields
10-
from rosidl_runtime_py.utilities import get_namespaced_type
118

129
from rai.communication.ros_communication import SingleMessageGrabber
1310

1411
from .utils import import_message_from_str
1512

1613

1714
@tool
18-
def get_topics_names_and_types():
15+
def get_topics_names_and_types_tool():
1916
"""Call rclpy.node.Node().get_topics_names_and_types(). Return in a csv format. topic_name, serice_type"""
2017

21-
node = Node(node_name="rai_tool_node")
18+
node = Node(node_name="rai_get_topics_names_and_types_tool")
2219
rclpy.spin_once(node, timeout_sec=2)
2320
try:
2421
return [
@@ -84,9 +81,8 @@ class Ros2PubMessageTool(BaseTool):
8481

8582
def _build_msg(
8683
self, msg_type: str, msg_args: Dict[str, Any]
87-
) -> Tuple[object, object]:
88-
msg_namespaced_type: NamespacedType = get_namespaced_type(msg_type)
89-
msg_cls = import_message_from_namespaced_type(msg_namespaced_type)
84+
) -> Tuple[object, Type]:
85+
msg_cls: Type = import_message_from_str(msg_type)
9086
msg = msg_cls()
9187
set_message_fields(msg, msg_args)
9288
return msg, msg_cls
@@ -96,7 +92,7 @@ def _run(self, topic_name: str, msg_type: str, msg_args: Dict[str, Any]):
9692

9793
msg, msg_cls = self._build_msg(msg_type, msg_args)
9894

99-
node = Node(node_name="RAI_PubRos2MessageTool")
95+
node = Node(node_name=f"rai_{self,__class__.__name__}")
10096
publisher = node.create_publisher(
10197
msg_cls, topic_name, 10
10298
) # TODO(boczekbartek): infer qos profile from topic info

0 commit comments

Comments
 (0)