Skip to content

Commit e54f7ff

Browse files
maciejmajekboczekbartek
authored andcommitted
feat: husarion toolset (#41)
* build: specify flake8 version * feat: add transform listener * refactor: use /tf for getting current pose * feat: wip husarion the describer demo example * feat: husarion the descriptor wip * refactor: rename tools.ros.cli_tools to tools.ros.cli fix: append human message if agent returns an aimsg without a tool call (hacky) fix: robot_y calculation bug * style: adjust too long line * refactor: rename SetWaypointTool to AddDescribedWaypointToDatabaseTool * refactor: set waypoint tool * style: use keyword arguments for cv2 functions * misc: update tool docstring, rename example * fix: remove -v flag as not all ros2 topic commands support it * docs: document new example
1 parent 60bfd4f commit e54f7ff

File tree

9 files changed

+322
-138
lines changed

9 files changed

+322
-138
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ repos:
4141
- id: black
4242

4343
- repo: https://github.com/pycqa/flake8
44-
rev: "" # pick a git hash / tag to point to
44+
rev: 7.1.0
4545
hooks:
4646
- id: flake8
4747
args: ["--ignore=E501,E731"]

examples/README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@ python examples/agri_example.py
1414

1515
In this demo all images are hardcoded.
1616

17+
## For husarion The Describer demo
18+
19+
Husarion is supposed to roam the environment and describe what it sees. \
20+
The observations are saved to `map_database.txt` in the format `x: {}, y: {}, z: {}, observation`.
21+
The demo only works with bedrock for now.
22+
23+
```bash
24+
python examples/explore_and_describe_bedrock.py
25+
```
26+
1727
### Help
1828

1929
```bash
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import os
2+
from typing import List, Type
3+
4+
from langchain_aws import ChatBedrock
5+
from langchain_core.language_models.chat_models import BaseChatModel
6+
from langchain_core.messages import SystemMessage
7+
from langchain_core.pydantic_v1 import BaseModel, Field
8+
from langchain_core.tools import BaseTool
9+
10+
from rai.communication.ros_communication import TF2TransformFetcher
11+
from rai.scenario_engine.messages import AgentLoop, HumanMultimodalMessage
12+
from rai.scenario_engine.scenario_engine import ScenarioPartType, ScenarioRunner
13+
from rai.scenario_engine.tool_runner import run_requested_tools
14+
from rai.tools.ros.cat_demo_tools import FinishTool
15+
from rai.tools.ros.cli import Ros2TopicTool, SetGoalPoseTool
16+
from rai.tools.ros.tools import (
17+
AddDescribedWaypointToDatabaseTool,
18+
GetCameraImageTool,
19+
GetOccupancyGridTool,
20+
)
21+
22+
23+
class DescribeAreaToolInput(BaseModel):
    """Input for the describe_area tool."""

    # NOTE: the docstring above is kept verbatim on purpose; pydantic may expose
    # a model's docstring in its generated schema.

    # Topic name handed to GetCameraImageTool by DescribeAreaTool._run.
    image_topic: str = Field(
        ...,
        description="ROS2 image topic to subscribe to",
    )
27+
28+
29+
class DescribeAreaTool(BaseTool):
    """
    Ask an LLM to describe the robot's surroundings from a camera image.

    The LLM is offered a single tool, AddDescribedWaypointToDatabaseTool, which
    is configured with ``map_database``. Any tool call the LLM requests is
    executed, so the description is persisted through that tool rather than
    returned by this one.
    """

    name: str = "DescribeAreaTool"
    description: str = "A tool for describing the area around the robot."
    args_schema: Type[DescribeAreaToolInput] = DescribeAreaToolInput

    # Chat model without tools bound; the waypoint tool is bound on every run.
    llm: BaseChatModel
    # System prompt placed first in the conversation sent to ``llm``.
    system_message: SystemMessage
    # Passed straight through to AddDescribedWaypointToDatabaseTool.
    map_database: str = ""

    def _run(self, image_topic: str):
        """
        Describe the area seen on ``image_topic`` and let the LLM store it.

        Args:
            image_topic: ROS2 image topic handed to GetCameraImageTool.

        Returns:
            A fixed confirmation string; the description itself is not returned.
        """
        camera_tool = GetCameraImageTool()
        waypoint_tool = AddDescribedWaypointToDatabaseTool(
            map_database=self.map_database
        )

        # Pose is fetched before the image, matching the order callers rely on.
        robot_pose = TF2TransformFetcher().get_data()
        camera_images = camera_tool.run(image_topic)["images"]
        describer = self.llm.bind_tools([waypoint_tool])  # type: ignore

        prompt = (
            "Describe the area around the robot (area, not items). Reason how would you name the room you are currently in"
            f". Use available tooling. Your current position is: {robot_pose}"
        )
        conversation = [
            self.system_message,
            HumanMultimodalMessage(content=prompt, images=camera_images),
        ]

        response = describer.invoke(conversation)
        conversation.append(response)
        # TODO(@maciejmajek): fix hardcoded llm_type
        run_requested_tools(
            response, [waypoint_tool], conversation, llm_type="bedrock"
        )
        return "Description of the area completed."
65+
66+
67+
# System prompt for the tool-less LLM used inside DescribeAreaTool.
DESCRIBER_PROMPT = """
You are an expert in describing the environment around you. Your main goal is to describe the area based on what you see in the image.
"""


def main():
    """
    Run the explore-and-describe demo against Bedrock-hosted models.

    Ensures the map database file exists, builds the toolset (including
    DescribeAreaTool backed by a smaller model), runs the scenario with an
    agent loop of at most 50 iterations, and saves the run to HTML.
    """
    map_database_path = "map_database.txt"

    # Append mode creates the file when it is missing and never truncates
    # existing content, so no separate (race-prone) existence check is needed.
    with open(map_database_path, "a"):
        pass

    simple_llm = ChatBedrock(
        model_id="anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2"  # type: ignore
    )
    tools = [
        GetOccupancyGridTool(),
        SetGoalPoseTool(),
        Ros2TopicTool(),
        DescribeAreaTool(
            llm=simple_llm,
            system_message=SystemMessage(content=DESCRIBER_PROMPT),
            map_database=map_database_path,
        ),
        FinishTool(),
    ]

    scenario: List[ScenarioPartType] = [
        SystemMessage(
            content="You are an autonomous agent. Your main goal is to fulfill the user's requests. "
            "Do not make assumptions about the environment you are currently in. "
            "Use the tooling provided to gather information about the environment. Remember to list available topics. "
        ),
        HumanMultimodalMessage(
            # Every fragment ends with a space so adjacent sentences do not run
            # together once the literals are concatenated.
            content="Describe your surroundings and gather more information as needed. "
            "Move to explore further, aiming for clear areas near the robot (red arrow). Make sure to describe the area during movement. "
            "Utilize available methods to obtain the map and identify relevant data streams. "
            "Before starting the exploration, find out what kind of tooling is available and based on that plan your exploration. For description, use the available tooling."
        ),
        # The loop stops when the agent calls the tool named after FinishTool.
        AgentLoop(stop_action=FinishTool.__name__, stop_iters=50),
    ]

    llm = ChatBedrock(model_id="anthropic.claude-3-5-sonnet-20240620-v1:0", region_name="us-east-1")  # type: ignore
    runner = ScenarioRunner(scenario, llm=llm, tools=tools, llm_type="bedrock")
    runner.run()
    runner.save_to_html()


if __name__ == "__main__":
    main()

examples/husarion_poc_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@
88
from rai.scenario_engine.scenario_engine import ScenarioPartType, ScenarioRunner
99
from rai.tools.hmi_tools import PlayVoiceMessageTool, WaitForSecondsTool
1010
from rai.tools.ros.cat_demo_tools import FinishTool
11-
from rai.tools.ros.cli_tools import (
11+
from rai.tools.ros.cli import (
1212
Ros2InterfaceTool,
1313
Ros2ServiceTool,
1414
Ros2TopicTool,
1515
SetGoalPoseTool,
1616
)
1717
from rai.tools.ros.tools import (
18+
AddDescribedWaypointToDatabaseTool,
1819
GetCameraImageTool,
1920
GetCurrentPositionTool,
2021
GetOccupancyGridTool,
21-
SetWaypointTool,
2222
)
2323

2424

@@ -32,7 +32,7 @@ def main():
3232
Ros2ServiceTool(),
3333
Ros2InterfaceTool(),
3434
SetGoalPoseTool(),
35-
SetWaypointTool(),
35+
AddDescribedWaypointToDatabaseTool(),
3636
GetCurrentPositionTool(),
3737
FinishTool(),
3838
]

src/rai/communication/ros_communication.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,20 @@
77
import rclpy
88
import rclpy.qos
99
from cv_bridge import CvBridge
10+
from rclpy.duration import Duration
1011
from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy
1112
from rclpy.node import Node
12-
from rclpy.qos import QoSProfile
13+
from rclpy.qos import (
14+
QoSDurabilityPolicy,
15+
QoSHistoryPolicy,
16+
QoSLivelinessPolicy,
17+
QoSProfile,
18+
QoSReliabilityPolicy,
19+
)
1320
from rclpy.signals import SignalHandlerGuardCondition
1421
from rclpy.utilities import timeout_sec_to_nsec
1522
from sensor_msgs.msg import Image
23+
from tf2_ros import Buffer, TransformListener
1624

1725

1826
def wait_for_message(
@@ -91,7 +99,18 @@ def grab_message(self) -> Any:
9199

92100
node = rclpy.create_node(self.__class__.__name__ + "_node") # type: ignore
93101
qos_profile = rclpy.qos.qos_profile_sensor_data
94-
102+
if (
103+
self.topic == "/map"
104+
): # overfitting to husarion TODO(maciejmajek): find a better way
105+
qos_profile = QoSProfile(
106+
reliability=QoSReliabilityPolicy.RELIABLE,
107+
history=QoSHistoryPolicy.KEEP_ALL,
108+
durability=QoSDurabilityPolicy.TRANSIENT_LOCAL,
109+
lifespan=Duration(seconds=0),
110+
deadline=Duration(seconds=0),
111+
liveliness=QoSLivelinessPolicy.AUTOMATIC,
112+
liveliness_lease_duration=Duration(seconds=0),
113+
)
95114
success, msg = wait_for_message(
96115
self.message_type,
97116
node,
@@ -197,3 +216,43 @@ def get_data(self):
197216
command = "ros2 action list"
198217
output = subprocess.check_output(command, shell=True).decode("utf-8")
199218
return output
219+
220+
221+
class TF2Listener(Node):
    """ROS2 node that looks up the ``base_link`` pose in the ``map`` frame via tf2."""

    def __init__(self):
        super().__init__("tf2_listener")

        # Result of the last successful lookup; None until one succeeds.
        self.transform = None

        # tf2 buffer plus the listener that is attached to this node.
        self.tf_buffer = Buffer()
        self.tf_listener = TransformListener(self.tf_buffer, self)

    def get_transform(self):
        """
        Attempt one transform lookup and store the result in ``self.transform``.

        Any failure is logged at debug level and leaves ``self.transform``
        unchanged, so callers can simply retry.
        """
        try:
            # A default-constructed Time is passed as the lookup time.
            self.transform = self.tf_buffer.lookup_transform(
                "map", "base_link", rclpy.time.Time()
            )
        except Exception as err:
            self.get_logger().debug(f"Could not transform: {err}")
239+
240+
241+
class TF2TransformFetcher:
    """One-shot helper that fetches the transform exposed by TF2Listener."""

    def get_data(self):
        """
        Initialise rclpy, spin a TF2Listener until a transform is available,
        then tear everything down again.

        Returns:
            The transform stored by ``TF2Listener.get_transform``, or None if
            the wait ended (Ctrl+C or rclpy shutdown) before a lookup succeeded.
        """
        rclpy.init()
        try:
            node = TF2Listener()
            executor = rclpy.executors.SingleThreadedExecutor()
            executor.add_node(node)
            try:
                while rclpy.ok() and node.transform is None:
                    node.get_transform()
                    # Spin the executor that owns the node, so TF callbacks
                    # run on it rather than on rclpy's global executor.
                    executor.spin_once(timeout_sec=1.0)
            except KeyboardInterrupt:
                # Best effort: return whatever was obtained so far.
                pass
            finally:
                transform = node.transform
                executor.shutdown()
                node.destroy_node()
            return transform
        finally:
            # Always release the rclpy context, even if setup or spinning raised.
            rclpy.shutdown()

src/rai/scenario_engine/scenario_engine.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,14 @@ def _run(self, scenario: ScenarioType):
151151
f"Looping agent actions until {msg.stop_action}. Max {msg.stop_iters} loops."
152152
)
153153
for _ in range(msg.stop_iters):
154+
# if the last message is from the AI, we need to add a human message to continue the agent loop
155+
# otherwise the bedrock model will not be able to continue the conversation
156+
if self.history[-1].type == "ai":
157+
self.history.append(
158+
HumanMessage(
159+
content="Thank you. Please continue your mision using tools."
160+
)
161+
)
154162
ai_msg = cast(
155163
AIMessage,
156164
self.llm_with_tools.invoke(

src/rai/tools/ros/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44
UseHonkTool,
55
UseLightsTool,
66
)
7-
from .cli_tools import Ros2InterfaceTool, Ros2ServiceTool, Ros2TopicTool
7+
from .cli import Ros2InterfaceTool, Ros2ServiceTool, Ros2TopicTool
88
from .mock_tools import (
99
ObserveSurroundingsTool,
1010
OpenSetSegmentationTool,
1111
VisualQuestionAnsweringTool,
1212
)
1313
from .tools import (
14+
AddDescribedWaypointToDatabaseTool,
1415
GetCameraImageTool,
1516
GetCurrentPositionTool,
1617
GetOccupancyGridTool,
17-
SetWaypointTool,
1818
)
1919

2020
__all__ = [
@@ -28,7 +28,7 @@
2828
"Ros2TopicTool",
2929
"Ros2InterfaceTool",
3030
"Ros2ServiceTool",
31-
"SetWaypointTool",
31+
"AddDescribedWaypointToDatabaseTool",
3232
"GetOccupancyGridTool",
3333
"GetCameraImageTool",
3434
"GetCurrentPositionTool",
File renamed without changes.

0 commit comments

Comments
 (0)