Skip to content

feat: ros2 connector #379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions src/rai/rai/communication/ari_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Generic, Optional, TypeVar

from pydantic import Field
from typing import Any, Dict, Generic, Optional, TypeVar

from .base_connector import BaseConnector, BaseMessage

Expand All @@ -26,15 +24,14 @@ class ARIMessage(BaseMessage):
Inherit from this class to create specific ARI message types.
"""


# TODO: Move this to ros2 module
class ROS2RRIMessage(ARIMessage):
ros_message_type: str = Field(
description="The string representation of the ROS message type (e.g. 'std_msgs/msg/String')"
)
python_message_class: Optional[type] = Field(
description="The Python class of the ROS message type", default=None
)
def __init__(
self,
payload: Any,
metadata: Optional[Dict[str, Any]] = None,
*args: Any,
**kwargs: Any,
):
super().__init__(payload, metadata, *args, **kwargs)


T = TypeVar("T", bound=ARIMessage)
Expand Down
17 changes: 14 additions & 3 deletions src/rai/rai/communication/base_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,26 @@
# limitations under the License.

from abc import abstractmethod
from typing import Any, Callable, Generic, Optional, Protocol, TypeVar
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
from uuid import uuid4


class BaseMessage(Protocol):
class BaseMessage:
payload: Any
metadata: Dict[str, Any]

def __init__(self, payload: Any, *args, **kwargs):
def __init__(
self,
payload: Any,
metadata: Optional[Dict[str, Any]] = None,
*args: Any,
**kwargs: Any,
):
self.payload = payload
if metadata is None:
self.metadata = {}
else:
self.metadata = metadata


T = TypeVar("T", bound=BaseMessage)
Expand Down
99 changes: 90 additions & 9 deletions src/rai/rai/communication/ros2/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Annotated, Any, Dict, List, Optional, Tuple, Type, TypedDict, cast
from typing import (
Annotated,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypedDict,
cast,
)

import rclpy
import rclpy.callback_groups
Expand Down Expand Up @@ -146,7 +159,7 @@ def publish(
topic: str,
msg_content: Dict[str, Any],
msg_type: str,
*, # Force keyword arguments
*,
auto_qos_matching: bool = True,
qos_profile: Optional[QoSProfile] = None,
) -> None:
Expand All @@ -170,11 +183,20 @@ def publish(
publisher = self._get_or_create_publisher(topic, type(msg), qos_profile)
publisher.publish(msg)

def _verify_receive_args(
self, topic: str, auto_topic_type: bool, msg_type: Optional[str]
) -> None:
if auto_topic_type and msg_type is not None:
raise ValueError("Cannot provide both auto_topic_type and msg_type")
if not auto_topic_type and msg_type is None:
raise ValueError("msg_type must be provided if auto_topic_type is False")

def receive(
self,
topic: str,
msg_type: str,
*, # Force keyword arguments
*,
auto_topic_type: bool = True,
msg_type: Optional[str] = None,
timeout_sec: float = 1.0,
auto_qos_matching: bool = True,
qos_profile: Optional[QoSProfile] = None,
Expand All @@ -193,8 +215,20 @@ def receive(

Raises:
ValueError: If no publisher exists or no message is received within timeout
ValueError: If auto_topic_type is False and msg_type is not provided
ValueError: If auto_topic_type is True and msg_type is provided
"""
self._verify_publisher_exists(topic)
self._verify_receive_args(topic, auto_topic_type, msg_type)
topic_endpoints = self._verify_publisher_exists(topic)

# TODO: Verify publishers topic type consistency
if auto_topic_type:
msg_type = topic_endpoints[0].topic_type
else:
if msg_type is None:
raise ValueError(
"msg_type must be provided if auto_topic_type is False"
)

qos_profile = self._resolve_qos_profile(
topic, auto_qos_matching, qos_profile, for_publisher=False
Expand Down Expand Up @@ -260,16 +294,18 @@ def _get_message_class(msg_type: str) -> Type[Any]:
"""Convert message type string to actual message class."""
return import_message_from_str(msg_type)

def _verify_publisher_exists(self, topic: str) -> None:
def _verify_publisher_exists(self, topic: str) -> List[TopicEndpointInfo]:
"""Verify that at least one publisher exists for the given topic.

Raises:
ValueError: If no publisher exists for the topic
"""
if not self._node.get_publishers_info_by_topic(topic):
topic_endpoints = self._node.get_publishers_info_by_topic(topic)
if not topic_endpoints:
raise ValueError(f"No publisher found for topic: {topic}")
return topic_endpoints

def __del__(self) -> None:
def shutdown(self) -> None:
"""Cleanup publishers when object is destroyed."""
for publisher in self._publishers.values():
publisher.destroy()
Expand Down Expand Up @@ -324,18 +360,52 @@ def __init__(self, node: rclpy.node.Node) -> None:
self.node = node
self._logger = node.get_logger()
self.actions: Dict[str, ROS2ActionData] = {}
self._callback_executor = ThreadPoolExecutor(max_workers=10)

def _generate_handle(self):
return str(uuid.uuid4())

def _generic_callback(self, handle: str, feedback_msg: Any) -> None:
self.actions[handle]["feedbacks"].append(feedback_msg.feedback)

def _fan_out_feedback(
self, callbacks: List[Callable[[Any], None]], feedback_msg: Any
) -> None:
"""Fan out feedback message to multiple callbacks concurrently.

Args:
callbacks: List of callback functions to execute
feedback_msg: The feedback message to pass to each callback
"""
for callback in callbacks:
self._callback_executor.submit(
self._safe_callback_wrapper, callback, feedback_msg
)

def _safe_callback_wrapper(
self, callback: Callable[[Any], None], feedback_msg: Any
) -> None:
"""Safely execute a callback with error handling.

Args:
callback: The callback function to execute
feedback_msg: The feedback message to pass to the callback
"""
try:
callback(copy.deepcopy(feedback_msg))
except Exception as e:
self._logger.error(f"Error in feedback callback: {str(e)}")

def send_goal(
self,
action_name: str,
action_type: str,
goal: Dict[str, Any],
*,
feedback_callback: Callable[[Any], None] = lambda _: None,
done_callback: Callable[
[Any], None
] = lambda _: None, # TODO: handle done callback
timeout_sec: float = 1.0,
) -> Tuple[bool, Annotated[str, "action handle"]]:
handle = self._generate_handle()
Expand All @@ -355,8 +425,13 @@ def send_goal(
if not action_client.wait_for_server(timeout_sec=timeout_sec): # type: ignore
return False, ""

feedback_callbacks = [
partial(self._generic_callback, handle),
feedback_callback,
]
send_goal_future: Future = action_client.send_goal_async(
goal=action_goal, feedback_callback=partial(self._generic_callback, handle)
goal=action_goal,
feedback_callback=partial(self._fan_out_feedback, feedback_callbacks),
)
self.actions[handle]["action_client"] = action_client
self.actions[handle]["goal_future"] = send_goal_future
Expand All @@ -372,6 +447,7 @@ def send_goal(
return False, ""

get_result_future = cast(Future, goal_handle.get_result_async()) # type: ignore
get_result_future.add_done_callback(done_callback) # type: ignore

self.actions[handle]["result_future"] = get_result_future
self.actions[handle]["client_goal_handle"] = goal_handle
Expand Down Expand Up @@ -403,3 +479,8 @@ def get_result(self, handle: str) -> Any:
if self.actions[handle]["result_future"] is None:
raise ValueError(f"No result available for goal {handle}")
return self.actions[handle]["result_future"].result()

def shutdown(self) -> None:
"""Cleanup thread pool when object is destroyed."""
if hasattr(self, "_callback_executor"):
self._callback_executor.shutdown(wait=False)
126 changes: 126 additions & 0 deletions src/rai/rai/communication/ros2/connectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import uuid
from typing import Any, Callable, Dict, Optional

from rclpy.executors import MultiThreadedExecutor
from rclpy.node import Node

from rai.communication.ari_connector import ARIConnector, ARIMessage
from rai.communication.ros2.api import ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI


class ROS2ARIMessage(ARIMessage):
def __init__(self, payload: Any, metadata: Optional[Dict[str, Any]] = None):
super().__init__(payload, metadata)


class ROS2ARIConnector(ARIConnector[ROS2ARIMessage]):
def __init__(
self, node_name: str = f"rai_ros2_ari_connector_{str(uuid.uuid4())[-12:]}"
):
super().__init__()
self._node = Node(node_name)
self._topic_api = ROS2TopicAPI(self._node)
self._service_api = ROS2ServiceAPI(self._node)
self._actions_api = ROS2ActionAPI(self._node)

self._executor = MultiThreadedExecutor()
self._executor.add_node(self._node)
self._thread = threading.Thread(target=self._executor.spin)
self._thread.start()

def send_message(self, message: ROS2ARIMessage, target: str):
auto_qos_matching = message.metadata.get("auto_qos_matching", True)
qos_profile = message.metadata.get("qos_profile", None)
msg_type = message.metadata.get("msg_type", None)

# TODO: allow msg_type to be None, add auto topic type detection
if msg_type is None:
raise ValueError("msg_type is required")

self._topic_api.publish(
topic=target,
msg_content=message.payload,
msg_type=msg_type,
auto_qos_matching=auto_qos_matching,
qos_profile=qos_profile,
)

def receive_message(
self,
source: str,
timeout_sec: float = 1.0,
msg_type: Optional[str] = None,
auto_topic_type: bool = True,
) -> ROS2ARIMessage:
msg = self._topic_api.receive(
topic=source,
timeout_sec=timeout_sec,
msg_type=msg_type,
auto_topic_type=auto_topic_type,
)
return ROS2ARIMessage(
payload=msg, metadata={"msg_type": str(type(msg)), "topic": source}
)

def service_call(
self, message: ROS2ARIMessage, target: str, timeout_sec: float = 1.0
) -> ROS2ARIMessage:
msg = self._service_api.call_service(
service_name=target,
service_type=message.metadata["msg_type"],
request=message.payload,
timeout_sec=timeout_sec,
)
return ROS2ARIMessage(
payload=msg, metadata={"msg_type": str(type(msg)), "service": target}
)

def start_action(
self,
action_data: Optional[ROS2ARIMessage],
target: str,
on_feedback: Callable[[Any], None] = lambda _: None,
on_done: Callable[[Any], None] = lambda _: None,
timeout_sec: float = 1.0,
) -> str:
if not isinstance(action_data, ROS2ARIMessage):
raise ValueError("Action data must be of type ROS2ARIMessage")
msg_type = action_data.metadata.get("msg_type", None)
if msg_type is None:
raise ValueError("msg_type is required")
accepted, handle = self._actions_api.send_goal(
action_name=target,
action_type=msg_type,
goal=action_data.payload,
timeout_sec=timeout_sec,
feedback_callback=on_feedback,
done_callback=on_done,
)
if not accepted:
raise RuntimeError("Action goal was not accepted")
return handle

def terminate_action(self, action_handle: str):
self._actions_api.terminate_goal(action_handle)

def shutdown(self):
self._executor.shutdown()
self._thread.join()
self._actions_api.shutdown()
self._topic_api.shutdown()
self._node.destroy_node()
Loading