Skip to content

feat: husarion toolset #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ repos:
- id: black

- repo: https://github.com/pycqa/flake8
rev: "" # pick a git hash / tag to point to
rev: 7.1.0
hooks:
- id: flake8
args: ["--ignore=E501,E731"]
10 changes: 10 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ python examples/agri_example.py

In this demo all images are hardcoded.

## For husarion The Describer demo

The Husarion robot is supposed to roam the environment and describe what it sees. \
The observations are saved to `map_database.txt` in the format `x: {}, y: {}, z: {}, observation`. \
For now, the demo only works with Bedrock.

```bash
python examples/explore_and_describe_bedrock.py
```

### Help

```bash
Expand Down
115 changes: 115 additions & 0 deletions examples/explore_and_describe_bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import os
from typing import List, Type

from langchain_aws import ChatBedrock
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import SystemMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool

from rai.communication.ros_communication import TF2TransformFetcher
from rai.scenario_engine.messages import AgentLoop, HumanMultimodalMessage
from rai.scenario_engine.scenario_engine import ScenarioPartType, ScenarioRunner
from rai.scenario_engine.tool_runner import run_requested_tools
from rai.tools.ros.cat_demo_tools import FinishTool
from rai.tools.ros.cli import Ros2TopicTool, SetGoalPoseTool
from rai.tools.ros.tools import (
AddDescribedWaypointToDatabaseTool,
GetCameraImageTool,
GetOccupancyGridTool,
)


# Argument schema for DescribeAreaTool; it is referenced as that tool's
# `args_schema`, so the single field below is the only argument the tool takes.
class DescribeAreaToolInput(BaseModel):
    """Input for the describe_area tool."""

    # Required field (`...` means no default); passed to DescribeAreaTool._run
    # as `image_topic`.
    image_topic: str = Field(..., description="ROS2 image topic to subscribe to")


class DescribeAreaTool(BaseTool):
    """
    Describe the area. The tool uses the available tooling to describe the area around the robot.
    The output is saved to the map database.
    The tool does not return anything specific to the tool run.

    Attributes:
        llm: Chat model without tools bound; the waypoint tool is bound to it
            on every run.
        system_message: System message placed first in each description request.
        map_database: Value forwarded as ``map_database`` to
            AddDescribedWaypointToDatabaseTool.
        llm_type: Value forwarded as ``llm_type`` to ``run_requested_tools``.
            Defaults to ``"bedrock"``, which was previously hard-coded.
    """

    name: str = "DescribeAreaTool"
    description: str = "A tool for describing the area around the robot."
    args_schema: Type[DescribeAreaToolInput] = DescribeAreaToolInput

    llm: BaseChatModel  # without tools
    system_message: SystemMessage
    map_database: str = ""
    llm_type: str = "bedrock"

    def _run(self, image_topic: str):
        """Capture an image, ask the LLM to describe the area and run its tool calls.

        Args:
            image_topic: ROS2 image topic handed to GetCameraImageTool.

        Returns:
            A fixed confirmation string; the description itself is only
            persisted through whatever tool calls the model requests.
        """
        get_camera_image_tool = GetCameraImageTool()
        set_waypoint_tool = AddDescribedWaypointToDatabaseTool(
            map_database=self.map_database
        )

        # The position is fetched before the image and embedded in the prompt so
        # the model can attach it to the waypoint it stores.
        current_position = TF2TransformFetcher().get_data()
        image = get_camera_image_tool.run(image_topic)["images"]

        # Only the waypoint tool is exposed to the describing model.
        llm_with_tools = self.llm.bind_tools([set_waypoint_tool])  # type: ignore
        human_message = HumanMultimodalMessage(
            content="Describe the area around the robot (area, not items). Reason how would you name the room you are currently in"
            f". Use available tooling. Your current position is: {current_position}",
            images=image,
        )
        messages = [self.system_message, human_message]
        ai_msg = llm_with_tools.invoke(messages)
        messages.append(ai_msg)
        run_requested_tools(
            ai_msg, [set_waypoint_tool], messages, llm_type=self.llm_type
        )
        return "Description of the area completed."


# System prompt for the describing LLM; main() wraps it in a SystemMessage and
# passes it to DescribeAreaTool as `system_message`.
DESCRIBER_PROMPT = """
You are an expert in describing the environment around you. Your main goal is to describe the area based on what you see in the image.
"""


def main():
    """Run the explore-and-describe scenario against Bedrock models.

    Creates an empty ``map_database.txt`` in the current working directory if
    it does not exist, builds the tool set and scenario, runs the scenario and
    saves the result with ``runner.save_to_html()``.
    """
    # Single source of truth for the database path used below.
    map_database = "map_database.txt"

    # setup database for the example
    if not os.path.exists(map_database):
        with open(map_database, "w"):
            pass  # opening in "w" mode is enough to create the empty file

    simple_llm = ChatBedrock(
        model_id="anthropic.claude-3-haiku-20240307-v1:0", region_name="us-west-2"  # type: ignore
    )
    tools = [
        GetOccupancyGridTool(),
        SetGoalPoseTool(),
        Ros2TopicTool(),
        DescribeAreaTool(
            llm=simple_llm,
            system_message=SystemMessage(content=DESCRIBER_PROMPT),
            map_database=map_database,
        ),
        FinishTool(),
    ]

    scenario: List[ScenarioPartType] = [
        SystemMessage(
            content="You are an autonomous agent. Your main goal is to fulfill the user's requests. "
            "Do not make assumptions about the environment you are currently in. "
            "Use the tooling provided to gather information about the environment. Remember to list available topics. "
        ),
        HumanMultimodalMessage(
            # Every fragment ends with a space so the implicitly concatenated
            # sentences do not run together.
            content="Describe your surroundings and gather more information as needed. "
            "Move to explore further, aiming for clear areas near the robot (red arrow). Make sure to describe the area during movement. "
            "Utilize available methods to obtain the map and identify relevant data streams. "
            "Before starting the exploration, find out what kind of tooling is available and based on that plan your exploration. For description, use the available tooling."
        ),
        # The loop stops on the tool's class name; no instance is needed to read it.
        AgentLoop(stop_action=FinishTool.__name__, stop_iters=50),
    ]

    llm = ChatBedrock(model_id="anthropic.claude-3-5-sonnet-20240620-v1:0", region_name="us-east-1")  # type: ignore
    runner = ScenarioRunner(scenario, llm=llm, tools=tools, llm_type="bedrock")
    runner.run()
    runner.save_to_html()


# Run the demo only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
6 changes: 3 additions & 3 deletions examples/husarion_poc_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
from rai.scenario_engine.scenario_engine import ScenarioPartType, ScenarioRunner
from rai.tools.hmi_tools import PlayVoiceMessageTool, WaitForSecondsTool
from rai.tools.ros.cat_demo_tools import FinishTool
from rai.tools.ros.cli_tools import (
from rai.tools.ros.cli import (
Ros2InterfaceTool,
Ros2ServiceTool,
Ros2TopicTool,
SetGoalPoseTool,
)
from rai.tools.ros.tools import (
AddDescribedWaypointToDatabaseTool,
GetCameraImageTool,
GetCurrentPositionTool,
GetOccupancyGridTool,
SetWaypointTool,
)


Expand All @@ -32,7 +32,7 @@ def main():
Ros2ServiceTool(),
Ros2InterfaceTool(),
SetGoalPoseTool(),
SetWaypointTool(),
AddDescribedWaypointToDatabaseTool(),
GetCurrentPositionTool(),
FinishTool(),
]
Expand Down
63 changes: 61 additions & 2 deletions src/rai/communication/ros_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,20 @@
import rclpy
import rclpy.qos
from cv_bridge import CvBridge
from rclpy.duration import Duration
from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy
from rclpy.node import Node
from rclpy.qos import QoSProfile
from rclpy.qos import (
QoSDurabilityPolicy,
QoSHistoryPolicy,
QoSLivelinessPolicy,
QoSProfile,
QoSReliabilityPolicy,
)
from rclpy.signals import SignalHandlerGuardCondition
from rclpy.utilities import timeout_sec_to_nsec
from sensor_msgs.msg import Image
from tf2_ros import Buffer, TransformListener


def wait_for_message(
Expand Down Expand Up @@ -91,7 +99,18 @@ def grab_message(self) -> Any:

node = rclpy.create_node(self.__class__.__name__ + "_node") # type: ignore
qos_profile = rclpy.qos.qos_profile_sensor_data

if (
self.topic == "/map"
): # overfitting to husarion TODO(maciejmajek): find a better way
qos_profile = QoSProfile(
reliability=QoSReliabilityPolicy.RELIABLE,
history=QoSHistoryPolicy.KEEP_ALL,
durability=QoSDurabilityPolicy.TRANSIENT_LOCAL,
lifespan=Duration(seconds=0),
deadline=Duration(seconds=0),
liveliness=QoSLivelinessPolicy.AUTOMATIC,
liveliness_lease_duration=Duration(seconds=0),
)
success, msg = wait_for_message(
self.message_type,
node,
Expand Down Expand Up @@ -197,3 +216,43 @@ def get_data(self):
command = "ros2 action list"
output = subprocess.check_output(command, shell=True).decode("utf-8")
return output


class TF2Listener(Node):
    """Node that polls a TF2 buffer for a single transform.

    The most recently fetched transform is stored in ``self.transform``, which
    stays ``None`` until a lookup succeeds.
    """

    def __init__(self, target_frame: str = "map", source_frame: str = "base_link"):
        """Create the node together with its TF2 buffer and listener.

        Args:
            target_frame: Frame passed as the first argument to
                ``lookup_transform``. Defaults to ``"map"``.
            source_frame: Frame passed as the second argument to
                ``lookup_transform``. Defaults to ``"base_link"``.
        """
        super().__init__("tf2_listener")

        self.target_frame = target_frame
        self.source_frame = source_frame

        # Create a buffer and listener
        self.tf_buffer = Buffer()
        self.tf_listener = TransformListener(self.tf_buffer, self)

        # This will store the transform when received
        self.transform = None

    def get_transform(self):
        """Try one lookup and store the result in ``self.transform``.

        A failed lookup is logged at debug level and leaves ``self.transform``
        unchanged, so callers can simply retry.
        """
        try:
            # A default-constructed Time is time zero, not the current time.
            # NOTE(review): tf2 is expected to treat it as "latest available
            # transform" - confirm against the tf2_ros documentation.
            latest = rclpy.time.Time()
            self.transform = self.tf_buffer.lookup_transform(
                self.target_frame, self.source_frame, latest
            )
        except Exception as ex:
            # Deliberately best-effort: this is polled in a retry loop.
            self.get_logger().debug(f"Could not transform: {ex}")


class TF2TransformFetcher:
    """Fetch one transform by spinning a temporary TF2Listener node."""

    def get_data(self):
        """Block until TF2Listener obtains a transform and return it.

        rclpy is initialised on entry and shut down on exit, also when an
        exception propagates, so a later call can initialise it again.

        Returns:
            The transform stored on the listener node, or ``None`` if the loop
            ended first (``rclpy.ok()`` turned false or Ctrl-C was pressed).
        """
        rclpy.init()
        try:
            node = TF2Listener()
            executor = rclpy.executors.SingleThreadedExecutor()
            executor.add_node(node)
            try:
                while rclpy.ok() and node.transform is None:
                    node.get_transform()
                    # Spin the executor that owns the node; the original
                    # created this executor but spun the global one instead.
                    executor.spin_once(timeout_sec=1.0)
            except KeyboardInterrupt:
                pass
            finally:
                transform = node.transform
                executor.shutdown()
                node.destroy_node()
            return transform
        finally:
            rclpy.shutdown()
8 changes: 8 additions & 0 deletions src/rai/scenario_engine/scenario_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ def _run(self, scenario: ScenarioType):
f"Looping agent actions until {msg.stop_action}. Max {msg.stop_iters} loops."
)
for _ in range(msg.stop_iters):
# if the last message is from the AI, we need to add a human message to continue the agent loop
# otherwise the bedrock model will not be able to continue the conversation
if self.history[-1].type == "ai":
self.history.append(
HumanMessage(
content="Thank you. Please continue your mision using tools."
)
)
ai_msg = cast(
AIMessage,
self.llm_with_tools.invoke(
Expand Down
6 changes: 3 additions & 3 deletions src/rai/tools/ros/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
UseHonkTool,
UseLightsTool,
)
from .cli_tools import Ros2InterfaceTool, Ros2ServiceTool, Ros2TopicTool
from .cli import Ros2InterfaceTool, Ros2ServiceTool, Ros2TopicTool
from .mock_tools import (
ObserveSurroundingsTool,
OpenSetSegmentationTool,
VisualQuestionAnsweringTool,
)
from .tools import (
AddDescribedWaypointToDatabaseTool,
GetCameraImageTool,
GetCurrentPositionTool,
GetOccupancyGridTool,
SetWaypointTool,
)

__all__ = [
Expand All @@ -28,7 +28,7 @@
"Ros2TopicTool",
"Ros2InterfaceTool",
"Ros2ServiceTool",
"SetWaypointTool",
"AddDescribedWaypointToDatabaseTool",
"GetOccupancyGridTool",
"GetCameraImageTool",
"GetCurrentPositionTool",
Expand Down
File renamed without changes.
Loading