diff --git a/src/rai_bench/rai_bench/examples/manipulation_o3de/main.py b/src/rai_bench/rai_bench/examples/manipulation_o3de/main.py index b9beb5470..e17690a47 100644 --- a/src/rai_bench/rai_bench/examples/manipulation_o3de/main.py +++ b/src/rai_bench/rai_bench/examples/manipulation_o3de/main.py @@ -39,6 +39,7 @@ ) from rai_sim.o3de.o3de_bridge import ( O3DEngineArmManipulationBridge, + O3DExROS2SimulationConfig, ) @@ -116,28 +117,30 @@ def run_benchmark(model_name: str, vendor: str, out_dir: str): # ] ### import ready scenarios - t_scenarios = trivial_scenarios( - configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger - ) + t_scenarios = trivial_scenarios(configs_dir=configs_dir, logger=bench_logger) # e_scenarios = easy_scenarios( - # configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger + # configs_dir=configs_dir, logger=bench_logger # ) # m_scenarios = medium_scenarios( - # configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger + # configs_dir=configs_dir, logger=bench_logger # ) # h_scenarios = hard_scenarios( - # configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger + # configs_dir=configs_dir, logger=bench_logger # ) # vh_scenarios = very_hard_scenarios( - # configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger + # configs_dir=configs_dir, logger=bench_logger # ) all_scenarios = t_scenarios + simulation_config = O3DExROS2SimulationConfig.load_config( + config_path=Path(connector_path) + ) o3de = O3DEngineArmManipulationBridge(connector, logger=agent_logger) try: # define benchamrk benchmark = ManipulationO3DEBenchmark( model_name=model_name, + simulation_config=simulation_config, simulation_bridge=o3de, scenarios=all_scenarios, logger=bench_logger, diff --git a/src/rai_bench/rai_bench/examples/manipulation_o3de/scenarios.py b/src/rai_bench/rai_bench/examples/manipulation_o3de/scenarios.py index 44a810e80..f3b6b7d03 100644 --- a/src/rai_bench/rai_bench/examples/manipulation_o3de/scenarios.py +++ b/src/rai_bench/rai_bench/examples/manipulation_o3de/scenarios.py @@ -14,9 +14,7 @@ import logging from pathlib import Path -from typing import List, Union - -from rclpy.impl.rcutils_logger import RcutilsLogger +from typing import List from rai_bench.manipulation_o3de.benchmark import ManipulationO3DEBenchmark, Scenario from rai_bench.manipulation_o3de.interfaces import Task @@ -27,16 +25,12 @@ PlaceCubesTask, PlaceObjectAtCoordTask, ) -from rai_sim.o3de.o3de_bridge import ( - O3DExROS2SimulationConfig, -) - -loggers_type = Union[RcutilsLogger, logging.Logger] +from rai_sim.simulation_bridge import SceneConfig def trivial_scenarios( - configs_dir: str, connector_path: str, logger: loggers_type | None -) -> List[Scenario[O3DExROS2SimulationConfig]]: + configs_dir: str, logger: logging.Logger | None +) -> List[Scenario]: """Packet of trivial scenarios. The grading is subjective. This packet contains easy variants of 'easy' tasks with minimalistic scenes setups(1 object). @@ -59,18 +53,17 @@ def trivial_scenarios( List[Scenario[O3DExROS2SimulationConfig]] list of trivial scenarios """ - simulation_configs_paths: List[str] = [ + scene_configs_paths: List[str] = [ configs_dir + "1a.yaml", configs_dir + "1rc.yaml", configs_dir + "1t.yaml", configs_dir + "1yc.yaml", configs_dir + "1carrot.yaml", ] - simulations_configs = [ - O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) - for path in simulation_configs_paths + scene_configs = [ + SceneConfig.load_base_config(Path(path)) for path in scene_configs_paths ] - # place object at coords + # place object at coordss place_obj_types = [ "apple", "carrot", @@ -88,8 +81,8 @@ def trivial_scenarios( ) easy_place_objects_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=place_object_tasks, - simulation_configs=simulations_configs, - simulation_configs_paths=simulation_configs_paths, + scene_configs=scene_configs, + scene_configs_paths=scene_configs_paths, ) # move objects to the left object_groups = [["carrot"], ["red_cube"], ["tomato"], ["yellow_cube"]] @@ -101,16 +94,14 @@ def trivial_scenarios( easy_move_to_left_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=move_to_left_tasks, - simulation_configs=simulations_configs, - simulation_configs_paths=simulation_configs_paths, + scene_configs=scene_configs, + scene_configs_paths=scene_configs_paths, ) return [*easy_move_to_left_scenarios, *easy_place_objects_scenarios] -def easy_scenarios( - configs_dir: str, connector_path: str, logger: loggers_type | None -) -> List[Scenario[O3DExROS2SimulationConfig]]: +def easy_scenarios(configs_dir: str, logger: logging.Logger | None) -> List[Scenario]: """Packet of easy scenarios. The grading is subjective. This packet contains easy variants of 'easy' tasks with scenes containg no more than 3 objects @@ -135,7 +126,7 @@ def easy_scenarios( List[Scenario[O3DExROS2SimulationConfig]] list of easy scenarios """ - simulation_configs_paths: List[str] = [ + scene_configs_paths: List[str] = [ configs_dir + "1a_1t.yaml", configs_dir + "1a_2bc.yaml", configs_dir + "1bc_1rc_1yc.yaml", @@ -147,9 +138,8 @@ def easy_scenarios( configs_dir + "2a_1bc.yaml", configs_dir + "1carrot_1t_1rc.yaml", ] - simulations_configs = [ - O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) - for path in simulation_configs_paths + scene_configs = [ + SceneConfig.load_base_config(Path(path)) for path in scene_configs_paths ] # place object at coords place_obj_types = [ @@ -170,8 +160,8 @@ def easy_scenarios( ) easy_place_objects_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=place_object_tasks, - simulation_configs=simulations_configs, - simulation_configs_paths=simulation_configs_paths, + scene_configs=scene_configs, + scene_configs_paths=scene_configs_paths, logger=logger, ) # move objects to the left @@ -190,16 +180,16 @@ def easy_scenarios( easy_move_to_left_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=move_to_left_tasks, - simulation_configs=simulations_configs, - simulation_configs_paths=simulation_configs_paths, + scene_configs=scene_configs, + scene_configs_paths=scene_configs_paths, ) # place cubes task = PlaceCubesTask(threshold_distance=0.2, logger=logger) easy_place_cubes_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=[task], - simulation_configs=simulations_configs, - simulation_configs_paths=simulation_configs_paths, + scene_configs=scene_configs, + scene_configs_paths=scene_configs_paths, ) return [ @@ -209,9 +199,7 @@ def easy_scenarios( ] -def medium_scenarios( - configs_dir: str, connector_path: str, logger: loggers_type | None -) -> List[Scenario[O3DExROS2SimulationConfig]]: +def medium_scenarios(configs_dir: str, logger: logging.Logger | None) -> List[Scenario]: """Packet of medium scenarios. The grading is subjective. This packet contains harder variants of 'easy' tasks with scenes containg 4-7 objects and easy variants of 'hard' tasks with scenes contating 2-3 objects @@ -239,7 +227,7 @@ def medium_scenarios( List[Scenario[O3DExROS2SimulationConfig]] list of easy scenarios """ - medium_simulation_configs_paths: List[str] = [ + medium_scene_configs_paths: List[str] = [ configs_dir + "1rc_2bc_3yc.yaml", configs_dir + "2carrots_2a.yaml", configs_dir + "2yc_1bc_1rc.yaml", @@ -249,7 +237,7 @@ def medium_scenarios( configs_dir + "2a_1c_2rc.yaml", ] - easy_simulation_configs_paths: List[str] = [ + easy_scene_configs_paths: List[str] = [ configs_dir + "1a_1t.yaml", configs_dir + "1a_2bc.yaml", configs_dir + "1bc_1rc_1yc.yaml", @@ -261,13 +249,11 @@ def medium_scenarios( configs_dir + "2a_1bc.yaml", configs_dir + "1carrot_1t_1rc.yaml", ] - medium_simulations_configs = [ - O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) - for path in medium_simulation_configs_paths + medium_scene_configs = [ + SceneConfig.load_base_config(Path(path)) for path in medium_scene_configs_paths ] - easy_simulations_configs = [ - O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) - for path in easy_simulation_configs_paths + easy_scene_configs = [ + SceneConfig.load_base_config(Path(path)) for path in easy_scene_configs_paths ] # move objects to the left object_groups = [ @@ -286,8 +272,8 @@ def medium_scenarios( move_to_left_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=move_to_left_tasks, - simulation_configs=medium_simulations_configs, - simulation_configs_paths=medium_simulation_configs_paths, + scene_configs=medium_scene_configs, + scene_configs_paths=medium_scene_configs_paths, logger=logger, ) @@ -295,8 +281,8 @@ def medium_scenarios( task = PlaceCubesTask(threshold_distance=0.1, logger=logger) easy_place_cubes_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=[task], - simulation_configs=medium_simulations_configs, - simulation_configs_paths=medium_simulation_configs_paths, + scene_configs=medium_scene_configs, + scene_configs_paths=medium_scene_configs_paths, logger=logger, ) @@ -312,8 +298,8 @@ def medium_scenarios( build_tower_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=build_tower_tasks, - simulation_configs=easy_simulations_configs, - simulation_configs_paths=easy_simulation_configs_paths, + scene_configs=easy_scene_configs, + scene_configs_paths=easy_scene_configs_paths, ) # group object task @@ -332,8 +318,8 @@ def medium_scenarios( group_object_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=group_object_tasks, - simulation_configs=easy_simulations_configs, - simulation_configs_paths=easy_simulation_configs_paths, + scene_configs=easy_scene_configs, + scene_configs_paths=easy_scene_configs_paths, ) return [ *move_to_left_scenarios, @@ -343,9 +329,7 @@ def medium_scenarios( ] -def hard_scenarios( - configs_dir: str, connector_path: str, logger: loggers_type | None -) -> List[Scenario[O3DExROS2SimulationConfig]]: +def hard_scenarios(configs_dir: str, logger: logging.Logger | None) -> List[Scenario]: """Packet of hard scenarios. The grading is subjective. This packet contains harder variants of 'easy' tasks with majority of scenes containg 8+ objects, Objects can be positioned in an unusual way, for example stacked. @@ -374,7 +358,7 @@ def hard_scenarios( List[Scenario[O3DExROS2SimulationConfig]] list of easy scenarios """ - medium_simulation_configs_paths: List[str] = [ + medium_scene_configs_paths: List[str] = [ configs_dir + "1rc_2bc_3yc.yaml", configs_dir + "2carrots_2a.yaml", configs_dir + "2yc_1bc_1rc.yaml", @@ -384,7 +368,7 @@ def hard_scenarios( configs_dir + "2a_1c_2rc.yaml", ] - hard_simulation_configs_paths: List[str] = [ + hard_scene_configs_paths: List[str] = [ configs_dir + "3carrots_1a_1t_2bc_2yc.yaml", configs_dir + "1carrot_1a_2t_1bc_1rc_3yc_stacked.yaml", configs_dir + "2carrots_1a_1t_1bc_1rc_1yc_1corn.yaml", @@ -395,13 +379,11 @@ def hard_scenarios( configs_dir + "3carrots_1a_2bc_1rc_1yc_1corn.yaml", configs_dir + "3rc_3bc_stacked.yaml", ] - medium_simulations_configs = [ - O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) - for path in medium_simulation_configs_paths + medium_scene_configs = [ + SceneConfig.load_base_config(Path(path)) for path in medium_scene_configs_paths ] - hard_simulations_configs = [ - O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) - for path in hard_simulation_configs_paths + hard_scene_configs = [ + SceneConfig.load_base_config(Path(path)) for path in hard_scene_configs_paths ] # move objects to the left object_groups = [ @@ -420,16 +402,16 @@ def hard_scenarios( move_to_left_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=move_to_left_tasks, - simulation_configs=hard_simulations_configs, - simulation_configs_paths=hard_simulation_configs_paths, + scene_configs=hard_scene_configs, + scene_configs_paths=hard_scene_configs_paths, ) # place cubes task = PlaceCubesTask(threshold_distance=0.1, logger=logger) easy_place_cubes_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=[task], - simulation_configs=hard_simulations_configs, - simulation_configs_paths=hard_simulation_configs_paths, + scene_configs=hard_scene_configs, + scene_configs_paths=hard_scene_configs_paths, ) # build tower task @@ -444,9 +426,8 @@ def hard_scenarios( build_tower_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=build_tower_tasks, - simulation_configs=medium_simulations_configs, - simulation_configs_paths=medium_simulation_configs_paths, - logger=logger, + scene_configs=medium_scene_configs, + scene_configs_paths=medium_scene_configs_paths, ) # group object task @@ -466,8 +447,8 @@ def hard_scenarios( group_object_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=group_object_tasks, - simulation_configs=medium_simulations_configs, - simulation_configs_paths=medium_simulation_configs_paths, + scene_configs=medium_scene_configs, + scene_configs_paths=medium_scene_configs_paths, ) return [ *move_to_left_scenarios, @@ -478,8 +459,8 @@ def hard_scenarios( def very_hard_scenarios( - configs_dir: str, connector_path: str, logger: loggers_type | None -) -> List[Scenario[O3DExROS2SimulationConfig]]: + configs_dir: str, logger: logging.Logger | None +) -> List[Scenario]: """Packet of very_hard scenarios. The grading is subjective. This packet contains harder variants of 'hard' tasks with majority of scenes containg 8+ objects, Objects can be positioned in an unusual way, for example stacked. @@ -504,7 +485,7 @@ def very_hard_scenarios( List[Scenario[O3DExROS2SimulationConfig]] list of easy scenarios """ - hard_simulation_configs_paths: List[str] = [ + hard_scene_configs_paths: List[str] = [ configs_dir + "3carrots_1a_1t_2bc_2yc.yaml", configs_dir + "1carrot_1a_2t_1bc_1rc_3yc_stacked.yaml", configs_dir + "2carrots_1a_1t_1bc_1rc_1yc_1corn.yaml", @@ -515,9 +496,8 @@ def very_hard_scenarios( configs_dir + "3carrots_1a_2bc_1rc_1yc_1corn.yaml", configs_dir + "3rc_3bc_stacked.yaml", ] - hard_simulations_configs = [ - O3DExROS2SimulationConfig.load_config(Path(path), Path(connector_path)) - for path in hard_simulation_configs_paths + hard_scene_configs = [ + SceneConfig.load_base_config(Path(path)) for path in hard_scene_configs_paths ] # build tower task object_groups = [ @@ -536,8 +516,8 @@ def very_hard_scenarios( build_tower_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=build_tower_tasks, - simulation_configs=hard_simulations_configs, - simulation_configs_paths=hard_simulation_configs_paths, + scene_configs=hard_scene_configs, + scene_configs_paths=hard_scene_configs_paths, logger=logger, ) @@ -557,8 +537,8 @@ def very_hard_scenarios( group_object_scenarios = ManipulationO3DEBenchmark.create_scenarios( tasks=group_object_tasks, - simulation_configs=hard_simulations_configs, - simulation_configs_paths=hard_simulation_configs_paths, + scene_configs=hard_scene_configs, + scene_configs_paths=hard_scene_configs_paths, ) return [ *build_tower_scenarios, diff --git a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py index be3e06340..c737d5aa6 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/benchmark.py @@ -15,20 +15,30 @@ import statistics import time from pathlib import Path -from typing import Generic, List, TypeVar +from typing import List, TypeVar from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langgraph.errors import GraphRecursionError from langgraph.graph.state import CompiledStateGraph +from launch import LaunchDescription +from launch.actions import ( + IncludeLaunchDescription, +) +from launch.launch_description_sources import PythonLaunchDescriptionSource +from launch_ros.actions import Node +from launch_ros.substitutions import FindPackageShare from rai.messages import HumanMultimodalMessage from rai_bench.base_benchmark import BaseBenchmark, BenchmarkSummary from rai_bench.manipulation_o3de.interfaces import Task from rai_bench.manipulation_o3de.results_tracking import ScenarioResult +from rai_sim.o3de.o3de_bridge import ( + O3DEngineArmManipulationBridge, + O3DExROS2SimulationConfig, +) from rai_sim.simulation_bridge import ( Entity, - SimulationBridge, - SimulationConfigT, + SceneConfig, ) EntityT = TypeVar("EntityT", bound=Entity) @@ -38,7 +48,7 @@ class EntitiesMismatchException(Exception): pass -class Scenario(Generic[SimulationConfigT]): +class Scenario: """ A Scenario are defined by a pair of Task and Simlation Config. Each Scenario is executed separatly by a Benchmark. @@ -47,8 +57,8 @@ class Scenario(Generic[SimulationConfigT]): def __init__( self, task: Task, - simulation_config: SimulationConfigT, - simulation_config_path: str, + scene_config: SceneConfig, + scene_config_path: str, ) -> None: """ Initialize a Scenario. @@ -57,21 +67,21 @@ def __init__( ---------- task : Task The task to be executed. - simulation_config : SimulationConfigT - The simulation configuration for the scenario. - simulation_config_path : str - The file path to the simulation configuration. + scene_config : SceneConfig + The scene configuration for the scenario. + scene_config_path : str + The file path to the scene configuration. Raises ------ ValueError - If the provided simulation configuration is not valid for the task. + If the provided scene configuration is not valid for the task. """ self.task = task - self.simulation_config = simulation_config + self.scene_config = scene_config # NOTE (jmatejcz) needed for logging which config was used, # there probably is better way to do it - self.simulation_config_path = simulation_config_path + self.scene_config_path = scene_config_path class ManipulationO3DEBenchmark(BaseBenchmark): @@ -84,8 +94,9 @@ class ManipulationO3DEBenchmark(BaseBenchmark): def __init__( self, model_name: str, - simulation_bridge: SimulationBridge[SimulationConfigT], - scenarios: List[Scenario[SimulationConfigT]], + simulation_bridge: O3DEngineArmManipulationBridge, + simulation_config: O3DExROS2SimulationConfig, + scenarios: List[Scenario], results_dir: Path, logger: logging.Logger | None = None, ) -> None: @@ -95,20 +106,61 @@ def __init__( logger=logger, ) self.simulation_bridge = simulation_bridge + self.simulation_bridge.init_simulation(simulation_config=simulation_config) + self.simulation_bridge.launch_robotic_stack( + required_robotic_ros2_interfaces=simulation_config.required_robotic_ros2_interfaces, + launch_description=self.launch_description, + ) self.num_of_scenarios = len(scenarios) self.scenarios = enumerate(iter(scenarios)) self.scenario_results: List[ScenarioResult] = [] self.csv_initialize(self.results_filename, ScenarioResult) + @property + def launch_description(self): + launch_moveit = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + [ + "src/examples/rai-manipulation-demo/Project/Examples/panda_moveit_config_demo.launch.py", + ] + ) + ) + + launch_robotic_manipulation = Node( + package="robotic_manipulation", + executable="robotic_manipulation", + output="screen", + parameters=[ + {"use_sim_time": True}, + ], + ) + + launch_openset = IncludeLaunchDescription( + PythonLaunchDescriptionSource( + [ + FindPackageShare("rai_bringup"), + "/launch/openset.launch.py", + ] + ), + ) + + return LaunchDescription( + [ + launch_openset, + launch_moveit, + launch_robotic_manipulation, + ] + ) + @classmethod def create_scenarios( cls, tasks: List[Task], - simulation_configs: List[SimulationConfigT], - simulation_configs_paths: List[str], + scene_configs: List[SceneConfig], + scene_configs_paths: List[str], logger: logging.Logger | None = None, - ) -> List[Scenario[SimulationConfigT]]: + ) -> List[Scenario]: """ Create scenarios by pairing each task with each suitable simulation configuration. @@ -128,22 +180,22 @@ def create_scenarios( """ # NOTE (jmatejcz) hacky_fix, taking paths as args here, not the best solution, # but more changes to code would be required - scenarios: List[Scenario[SimulationConfigT]] = [] + scenarios: List[Scenario] = [] if not logger: logger = logging.getLogger(__name__) for task in tasks: - for sim_conf, sim_path in zip(simulation_configs, simulation_configs_paths): - if task.validate_config(simulation_config=sim_conf): + for scene_conf, scene_path in zip(scene_configs, scene_configs_paths): + if task.validate_config(simulation_config=scene_conf): scenarios.append( Scenario( task=task, - simulation_config=sim_conf, - simulation_config_path=sim_path, + scene_config=scene_conf, + scene_config_path=scene_path, ) ) else: logger.debug( - f"Simulation config: {sim_path} is not suitable for task: {task.task_prompt}" + f"Simulation config: {scene_path} is not suitable for task: {task.task_prompt}" ) return scenarios @@ -162,12 +214,12 @@ def run_next(self, agent: CompiledStateGraph) -> None: try: i, scenario = next(self.scenarios) # Get the next scenario - self.simulation_bridge.setup_scene(scenario.simulation_config) + self.simulation_bridge.setup_scene(scenario.scene_config) self.logger.info( "======================================================================================" ) self.logger.info( - f"RUNNING SCENARIO NUMBER {i + 1} / {self.num_of_scenarios}\n TASK: {scenario.task.task_prompt}\n SIMULATION_CONFIG: {scenario.simulation_config_path}" + f"RUNNING SCENARIO NUMBER {i + 1} / {self.num_of_scenarios}\n TASK: {scenario.task.task_prompt}\n SIMULATION_CONFIG: {scenario.scene_config_path}" ) tool_calls_num = 0 @@ -217,7 +269,7 @@ def run_next(self, agent: CompiledStateGraph) -> None: scenario_result = ScenarioResult( task_prompt=scenario.task.task_prompt, system_prompt=scenario.task.system_prompt, - simulation_config_path=scenario.simulation_config_path, + scene_config_path=scenario.scene_config_path, model_name=self.model_name, score=score, total_time=total_time, diff --git a/src/rai_bench/rai_bench/manipulation_o3de/interfaces.py b/src/rai_bench/rai_bench/manipulation_o3de/interfaces.py index 0d09a2135..91e595320 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/interfaces.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/interfaces.py @@ -22,8 +22,8 @@ from rai_sim.simulation_bridge import ( Entity, + SceneConfig, SimulationBridge, - SimulationConfig, SimulationConfigT, SpawnedEntity, ) @@ -68,7 +68,7 @@ def task_prompt(self) -> str: pass @abstractmethod - def validate_config(self, simulation_config: SimulationConfig) -> bool: + def validate_config(self, simulation_config: SceneConfig) -> bool: """Task should be able to verify if given config is suitable for specific task Args: @@ -389,9 +389,7 @@ def system_prompt(self) -> str: """ @abstractmethod - def check_if_required_objects_present( - self, simulation_config: SimulationConfig - ) -> bool: + def check_if_required_objects_present(self, simulation_config: SceneConfig) -> bool: """ Check if the required objects are present in the simulation configuration. @@ -402,9 +400,7 @@ def check_if_required_objects_present( """ return True - def check_if_any_placed_incorrectly( - self, simulation_config: SimulationConfig - ) -> bool: + def check_if_any_placed_incorrectly(self, simulation_config: SceneConfig) -> bool: """ Check if any object is placed incorrectly in the simulation configuration. Save number of initially correctly and incorrectly placed objects for @@ -418,7 +414,7 @@ def check_if_any_placed_incorrectly( _, incorrect = self.calculate_correct(entities=simulation_config.entities) return incorrect > 0 - def validate_config(self, simulation_config: SimulationConfig) -> bool: + def validate_config(self, simulation_config: SceneConfig) -> bool: """ Validate the simulation configuration. @@ -490,7 +486,7 @@ def calculate_current_placements( return current_correct, current_incorrect def calculate_score( - self, simulation_bridge: SimulationBridge[SimulationConfig] + self, simulation_bridge: SimulationBridge[SceneConfig] ) -> float: """ Calculate the task score based on the difference between initial and current placements. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/results_tracking.py b/src/rai_bench/rai_bench/manipulation_o3de/results_tracking.py index 02763e90a..eb278d3cb 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/results_tracking.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/results_tracking.py @@ -21,8 +21,8 @@ class ScenarioResult(BaseModel): task_prompt: str = Field(..., description="The task prompt.") system_prompt: str = Field(..., description="The system prompt.") model_name: str = Field(..., description="Name of the LLM.") - simulation_config_path: str = Field( - ..., description="Path to the simulation configuration file." + scene_config_path: str = Field( + ..., description="Path to the scene configuration file." ) score: float = Field( ..., description="Value between 0 and 1, describing the task score." diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/build_tower_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/build_tower_task.py index e64ff6428..771b99387 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/build_tower_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/build_tower_task.py @@ -20,7 +20,7 @@ from rai_bench.manipulation_o3de.interfaces import ( ManipulationTask, ) -from rai_sim.simulation_bridge import Entity, SimulationConfig +from rai_sim.simulation_bridge import Entity, SceneConfig loggers_type = Union[RcutilsLogger, logging.Logger] @@ -80,9 +80,7 @@ def task_prompt(self) -> str: cube_names = ", ".join(obj + "s" for obj in self.obj_types).replace("_", " ") return f"Manipulate objects so that all {cube_names} form a single vertical tower. Other types of objects cannot be included in a tower." - def check_if_required_objects_present( - self, simulation_config: SimulationConfig - ) -> bool: + def check_if_required_objects_present(self, simulation_config: SceneConfig) -> bool: """ Validate that at least two cubes of the specified types are present. diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/group_objects_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/group_objects_task.py index be0ece326..41ce103ec 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/group_objects_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/group_objects_task.py @@ -19,7 +19,7 @@ from rai_bench.manipulation_o3de.interfaces import ( ManipulationTask, ) -from rai_sim.simulation_bridge import Entity, SimulationConfig +from rai_sim.simulation_bridge import Entity, SceneConfig loggers_type = Union[RcutilsLogger, logging.Logger] @@ -61,9 +61,7 @@ def task_prompt(self) -> str: "4. Be completely separated from other clusters " ) - def check_if_required_objects_present( - self, simulation_config: SimulationConfig - ) -> bool: + def check_if_required_objects_present(self, simulation_config: SceneConfig) -> bool: """ Returns ------- diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/move_object_to_left_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/move_object_to_left_task.py index 53475ff76..b6032ec1b 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/move_object_to_left_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/move_object_to_left_task.py @@ -19,7 +19,7 @@ from rai_bench.manipulation_o3de.interfaces import ( ManipulationTask, ) -from rai_sim.simulation_bridge import Entity, SimulationConfig +from rai_sim.simulation_bridge import Entity, SceneConfig loggers_type = Union[RcutilsLogger, logging.Logger] @@ -44,9 +44,7 @@ def task_prompt(self) -> str: # but 'left side' is depending on where camera is positioned so it might not be enough return f"Manipulate objects, so that all of the {obj_names} are on the left side of the table (positive y)" - def check_if_required_objects_present( - self, simulation_config: SimulationConfig - ) -> bool: + def check_if_required_objects_present(self, simulation_config: SceneConfig) -> bool: """Validate if any object present""" object_types_present = self.group_entities_by_type( entities=simulation_config.entities diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_at_coord_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_at_coord_task.py index e35887ba8..cd0f9fc28 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_at_coord_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_at_coord_task.py @@ -21,7 +21,7 @@ from rai_bench.manipulation_o3de.interfaces import ( ManipulationTask, ) -from rai_sim.simulation_bridge import Entity, SimulationConfig +from rai_sim.simulation_bridge import Entity, SceneConfig loggers_type = Union[RcutilsLogger, logging.Logger] @@ -61,9 +61,7 @@ def task_prompt(self) -> str: f"the coordinates (x: {x}, y: {y})." ) - def check_if_required_objects_present( - self, simulation_config: SimulationConfig - ) -> bool: + def check_if_required_objects_present(self, simulation_config: SceneConfig) -> bool: count = sum( 1 for ent in simulation_config.entities if ent.prefab_name == self.obj_type ) diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_cubes_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_cubes_task.py index 1cad2c872..5ec7c5d1b 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_cubes_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/place_cubes_task.py @@ -19,7 +19,7 @@ from rai_bench.manipulation_o3de.interfaces import ( ManipulationTask, ) -from rai_sim.simulation_bridge import Entity, SimulationConfig +from rai_sim.simulation_bridge import Entity, SceneConfig loggers_type = Union[RcutilsLogger, logging.Logger] @@ -49,9 +49,7 @@ def __init__( def task_prompt(self) -> str: return "Manipulate objects, so that all cubes are adjacent to at least one cube" - def check_if_required_objects_present( - self, simulation_config: SimulationConfig - ) -> bool: + def check_if_required_objects_present(self, simulation_config: SceneConfig) -> bool: """ Returns ------- diff --git a/src/rai_bench/rai_bench/manipulation_o3de/tasks/rotate_object_task.py b/src/rai_bench/rai_bench/manipulation_o3de/tasks/rotate_object_task.py index 0ec58a4fb..e91af00c8 100644 --- a/src/rai_bench/rai_bench/manipulation_o3de/tasks/rotate_object_task.py +++ b/src/rai_bench/rai_bench/manipulation_o3de/tasks/rotate_object_task.py @@ -22,7 +22,7 @@ from rai_bench.manipulation_o3de.interfaces import ( ManipulationTask, ) -from rai_sim.simulation_bridge import Entity, SimulationConfig +from rai_sim.simulation_bridge import Entity, SceneConfig loggers_type = Union[RcutilsLogger, logging.Logger] @@ -58,9 +58,7 @@ def task_prompt(self) -> str: "Remember to rotate the gripper when grabbing objects." ) - def check_if_required_objects_present( - self, simulation_config: SimulationConfig - ) -> bool: + def check_if_required_objects_present(self, simulation_config: SceneConfig) -> bool: """ Validate that at least one object of the specified types is present. diff --git a/src/rai_bench/rai_bench/results_processing/data_loading.py b/src/rai_bench/rai_bench/results_processing/data_loading.py index 2cceb065d..dc45df5c7 100644 --- a/src/rai_bench/rai_bench/results_processing/data_loading.py +++ b/src/rai_bench/rai_bench/results_processing/data_loading.py @@ -101,7 +101,7 @@ def convert_row_to_scenario_result(row: pd.Series) -> ScenarioResult: task_prompt=row["task_prompt"], system_prompt=row["system_prompt"], model_name=row["model_name"], - simulation_config_path=row["simulation_config_path"], + scene_config_path=row["scene_config_path"], score=float(row["score"]), total_time=float(row["total_time"]), number_of_tool_calls=int(row["number_of_tool_calls"]), diff --git a/src/rai_sim/rai_sim/launch_manager.py b/src/rai_sim/rai_sim/launch_manager.py new file mode 100644 index 000000000..abee90bd2 --- /dev/null +++ b/src/rai_sim/rai_sim/launch_manager.py @@ -0,0 +1,65 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import multiprocessing +from multiprocessing.synchronize import Event +from typing import Optional + +from launch import LaunchDescription, LaunchService + + +class ROS2LaunchManager: + def __init__(self) -> None: + self._stop_event: Optional[Event] = None + self._process: Optional[multiprocessing.Process] = None + + def start(self, launch_description: LaunchDescription) -> None: + self._stop_event = multiprocessing.Event() + self._process = multiprocessing.Process( + target=self._run_process, + args=(self._stop_event, launch_description), + daemon=True, + ) + self._process.start() + + def shutdown(self) -> None: + if self._stop_event: + self._stop_event.set() + if self._process: + self._process.join() + + def _run_process( + self, stop_event: Event, launch_description: LaunchDescription + ) -> None: + loop = asyncio.get_event_loop() + asyncio.set_event_loop(loop) + launch_service = LaunchService() + launch_service.include_launch_description(launch_description) + # launch description launched + launch_task = loop.create_task(launch_service.run_async()) + # when stop event set + loop.run_until_complete(loop.run_in_executor(None, stop_event.wait)) + if not launch_task.done(): + # XXX (jmatejcz) the shutdown function sends shutdown signal to all + # nodes launch with launch description which should do the trick + # but some nodes are stubborn and there is a possibility + # that they don't close. If this will happen sending PKILL for all + # ros nodes will be needed + shutdown_task = loop.create_task( + launch_service.shutdown(), + ) + # shutdown task should complete when all nodes are closed + # but wait also for launch task to close just to be sure + loop.run_until_complete(asyncio.gather(shutdown_task, launch_task)) diff --git a/src/rai_sim/rai_sim/o3de/o3de_bridge.py b/src/rai_sim/rai_sim/o3de/o3de_bridge.py index f69a83b50..dac0e1027 100644 --- a/src/rai_sim/rai_sim/o3de/o3de_bridge.py +++ b/src/rai_sim/rai_sim/o3de/o3de_bridge.py @@ -13,7 +13,6 @@ # limitations under the License. import logging -import shlex import signal import subprocess import time @@ -23,6 +22,7 @@ import yaml from geometry_msgs.msg import Pose as ROS2Pose from geometry_msgs.msg import PoseStamped as ROS2PoseStamped +from launch import LaunchDescription from rai.communication.ros2 import ROS2Connector, ROS2Message from rai.communication.ros2.ros_async import get_future_result from rai.types import ( @@ -34,8 +34,10 @@ from tf2_geometry_msgs import do_transform_pose, do_transform_pose_stamped from rai_interfaces.srv import ManipulatorMoveTo +from rai_sim.launch_manager import ROS2LaunchManager from rai_sim.simulation_bridge import ( Entity, + SceneConfig, SceneState, SimulationBridge, SimulationConfig, @@ -47,19 +49,16 @@ class O3DExROS2SimulationConfig(SimulationConfig): binary_path: Path level: Optional[str] = None - robotic_stack_command: str required_simulation_ros2_interfaces: dict[str, List[str]] required_robotic_ros2_interfaces: dict[str, List[str]] - @classmethod - def load_config( - cls, base_config_path: Path, connector_config_path: Path - ) -> "O3DExROS2SimulationConfig": - base_config = SimulationConfig.load_base_config(base_config_path) + model_config = {"arbitrary_types_allowed": True} - with open(connector_config_path) as f: + @classmethod + def load_config(cls, config_path: Path) -> "O3DExROS2SimulationConfig": + with open(config_path) as f: connector_content: dict[str, Any] = yaml.safe_load(f) - return cls(**base_config.model_dump(), **connector_content) + return cls(**connector_content) class O3DExROS2Bridge(SimulationBridge[O3DExROS2SimulationConfig]): @@ -68,10 +67,17 @@ def __init__( ): super().__init__(logger=logger) self.connector = connector + self.manager = ROS2LaunchManager() self.current_sim_process = None - self.current_robotic_stack_process = None self.current_binary_path = None + def init_simulation(self, simulation_config: O3DExROS2SimulationConfig): + if self.current_binary_path != simulation_config.binary_path: + if self.current_sim_process: + self.shutdown() + self._launch_binary(simulation_config) + self.current_binary_path = simulation_config.binary_path + def shutdown(self): self._shutdown_binary() self._shutdown_robotic_stack() @@ -136,10 +142,7 @@ def _shutdown_binary(self): self.current_sim_process = None def _shutdown_robotic_stack(self): - self._shutdown_process( - process=self.current_robotic_stack_process, process_name="robotic_stack" - ) - self.current_robotic_stack_process = None + self.manager.shutdown() def get_available_spawnable_names(self) -> list[str]: msg = ROS2Message(payload={}) @@ -297,21 +300,13 @@ def _is_ros2_stack_ready( def setup_scene( self, - simulation_config: O3DExROS2SimulationConfig, + scene_config: SceneConfig, ): - if self.current_binary_path != simulation_config.binary_path: - if self.current_sim_process: - self.shutdown() - self._launch_binary(simulation_config) - self._launch_robotic_stack(simulation_config) - self.current_binary_path = simulation_config.binary_path + while self.spawned_entities: + self._despawn_entity(self.spawned_entities[0]) + self.logger.info(f"Entities after despawn: {self.spawned_entities}") - else: - while self.spawned_entities: - self._despawn_entity(self.spawned_entities[0]) - self.logger.info(f"Entities after despawn: {self.spawned_entities}") - - for entity in simulation_config.entities: + for entity in scene_config.entities: self._spawn_entity(entity) def _launch_binary( @@ -334,16 +329,15 @@ def _launch_binary( ): raise RuntimeError("ROS2 stack is not ready in time.") - def _launch_robotic_stack(self, simulation_config: O3DExROS2SimulationConfig): - command = shlex.split(simulation_config.robotic_stack_command) - self.logger.info(f"Running command: {command}") - self.current_robotic_stack_process = subprocess.Popen( - command, - ) - if not self._has_process_started(self.current_robotic_stack_process): - raise RuntimeError("Process did not start in time.") + def launch_robotic_stack( + self, + required_robotic_ros2_interfaces: dict[str, List[str]], + launch_description: LaunchDescription, + ): + self.manager.start(launch_description=launch_description) + if not self._is_ros2_stack_ready( - required_ros2_stack=simulation_config.required_robotic_ros2_interfaces + required_ros2_stack=required_robotic_ros2_interfaces ): raise RuntimeError("ROS2 stack is not ready in time.") diff --git a/src/rai_sim/rai_sim/simulation_bridge.py b/src/rai_sim/rai_sim/simulation_bridge.py index 5ef380b9c..bb514e270 100644 --- a/src/rai_sim/rai_sim/simulation_bridge.py +++ b/src/rai_sim/rai_sim/simulation_bridge.py @@ -70,7 +70,7 @@ class SpawnEntityService(ROS2BaseModel): xml: str = Field(default="") -class SimulationConfig(BaseModel): +class SceneConfig(BaseModel): """ Setup of simulation - arrangement of objects in the environment. @@ -110,7 +110,7 @@ def check_unique_names(cls, entities: List[Entity]) -> List[Entity]: return entities @classmethod - def load_base_config(cls, base_config_path: Path) -> "SimulationConfig": + def load_base_config(cls, base_config_path: Path) -> "SceneConfig": """ Loads a simulation configuration from a YAML file. @@ -164,6 +164,9 @@ class SceneState(BaseModel): ) +class SimulationConfig(BaseModel): ... + + SimulationConfigT = TypeVar("SimulationConfigT", bound=SimulationConfig) @@ -182,7 +185,15 @@ def __init__(self, logger: Optional[logging.Logger] = None): self.logger = logger @abstractmethod - def setup_scene(self, simulation_config: SimulationConfigT): + def init_simulation(self, simulation_config: SimulationConfigT): + """ + Initialize simulation binary and all other required processes, + for example ros2 nodes + """ + pass + + @abstractmethod + def setup_scene(self, scene_config: SceneConfig): """ Runs and sets up the simulation scene according to the provided configuration. diff --git a/tests/rai_sim/test_o3de_bridge.py b/tests/rai_sim/test_o3de_bridge.py index d077b6e9f..efad52207 100644 --- a/tests/rai_sim/test_o3de_bridge.py +++ b/tests/rai_sim/test_o3de_bridge.py @@ -14,6 +14,7 @@ import inspect import signal +import subprocess import typing import unittest from pathlib import Path @@ -22,6 +23,7 @@ import rclpy from geometry_msgs.msg import TransformStamped as ROS2TransformStamped +from launch import LaunchDescription from rai.communication.ros2 import ROS2Connector, ROS2Message from rai.types import ( Header, @@ -33,17 +35,15 @@ from rclpy.node import Node from rclpy.qos import QoSProfile +from rai_sim.launch_manager import ROS2LaunchManager from rai_sim.o3de.o3de_bridge import O3DExROS2Bridge, O3DExROS2SimulationConfig -from rai_sim.simulation_bridge import Entity, SpawnedEntity +from rai_sim.simulation_bridge import Entity, SceneConfig, SpawnedEntity -def test_load_config(sample_base_yaml_config: Path, sample_o3dexros2_config: Path): - config = O3DExROS2SimulationConfig.load_config( - sample_base_yaml_config, sample_o3dexros2_config - ) +def test_load_config(sample_o3dexros2_config: Path): + config = O3DExROS2SimulationConfig.load_config(sample_o3dexros2_config) assert isinstance(config, O3DExROS2SimulationConfig) assert config.binary_path == Path("/path/to/binary") - assert config.robotic_stack_command == "ros2 launch robotic_stack.launch.py" assert config.required_simulation_ros2_interfaces == { "services": ["/spawn_entity", "/delete_entity"], "topics": ["/color_image5", "/depth_image5", "/color_camera_info5"], @@ -58,6 +58,10 @@ def test_load_config(sample_base_yaml_config: Path, sample_o3dexros2_config: Pat "topics": [], "actions": ["/execute_trajectory"], } + + +def test_load_scene_config(sample_base_yaml_config: Path): + config = SceneConfig.load_base_config(sample_base_yaml_config) assert isinstance(config.entities, list) assert all(isinstance(e, Entity) for e in config.entities) @@ -65,9 +69,14 @@ def test_load_config(sample_base_yaml_config: Path, sample_o3dexros2_config: Pat class TestO3DExROS2Bridge(unittest.TestCase): - def setUp(self): + @patch("rai_sim.o3de.o3de_bridge.ROS2LaunchManager") + def setUp(self, mock_launch_manager_class): self.mock_connector = MagicMock(spec=ROS2Connector) self.mock_logger = MagicMock() + + self.mock_launch_manager = MagicMock(spec=ROS2LaunchManager) + mock_launch_manager_class.return_value = self.mock_launch_manager + self.bridge = O3DExROS2Bridge( connector=self.mock_connector, logger=self.mock_logger ) @@ -100,8 +109,6 @@ def setUp(self): self.test_config = O3DExROS2SimulationConfig( binary_path=Path("/path/to/binary"), - robotic_stack_command="ros2 launch robot.launch.py", - entities=[self.test_entity], required_simulation_ros2_interfaces={ "services": [], "topics": [], @@ -118,20 +125,34 @@ def test_init(self): self.assertEqual(self.bridge.connector, self.mock_connector) self.assertEqual(self.bridge.logger, self.mock_logger) self.assertIsNone(self.bridge.current_sim_process) - self.assertIsNone(self.bridge.current_robotic_stack_process) self.assertIsNone(self.bridge.current_binary_path) self.assertEqual(self.bridge.spawned_entities, []) - @patch("subprocess.Popen") - def test_launch_robotic_stack(self, mock_popen): - mock_process = MagicMock() - mock_process.poll.return_value = None - mock_process.pid = 54321 - mock_popen.return_value = mock_process - self.bridge._launch_robotic_stack(self.test_config) + def test_launch_robotic_stack(self): + mock_launch_description = MagicMock(spec=LaunchDescription) + + required_interfaces = { + "services": ["/test_service"], + "topics": ["/test_topic"], + "actions": ["/test_action"], + } + + self.bridge._is_ros2_stack_ready = MagicMock(return_value=True) + self.bridge.launch_robotic_stack(required_interfaces, mock_launch_description) - mock_popen.assert_called_once_with(["ros2", "launch", "robot.launch.py"]) - self.assertEqual(self.bridge.current_robotic_stack_process, mock_process) + self.mock_launch_manager.start.assert_called_once_with( + launch_description=mock_launch_description + ) + self.bridge._is_ros2_stack_ready.assert_called_once_with( + required_ros2_stack=required_interfaces + ) + + self.bridge._is_ros2_stack_ready.return_value = False + + with self.assertRaises(RuntimeError): + self.bridge.launch_robotic_stack( + required_interfaces, mock_launch_description + ) @patch("subprocess.Popen") def test_launch_binary(self, mock_popen): @@ -140,35 +161,106 @@ def test_launch_binary(self, mock_popen): mock_process.pid = 54322 mock_popen.return_value = mock_process + self.bridge._has_process_started = MagicMock(return_value=True) + self.bridge._is_ros2_stack_ready = MagicMock(return_value=True) + self.bridge._launch_binary(self.test_config) mock_popen.assert_called_once_with(["/path/to/binary"]) + self.assertEqual(self.bridge.current_sim_process, mock_process) - def test_shutdown_binary(self): - mock_process = MagicMock() - mock_process.poll.return_value = 0 + self.bridge._has_process_started.assert_called_once_with(process=mock_process) + self.bridge._is_ros2_stack_ready.assert_called_once() - self.bridge.current_sim_process = mock_process + def test_shutdown_process(self): + mock_process = MagicMock(spec=subprocess.Popen) + mock_process.pid = 12345 + process_name = "test_process" - self.bridge._shutdown_binary() + mock_process.wait.return_value = 0 + + self.bridge._shutdown_process(mock_process, process_name) mock_process.send_signal.assert_called_once_with(signal.SIGINT) mock_process.wait.assert_called_once() + mock_process.reset_mock() + + mock_process.wait.side_effect = [ + subprocess.TimeoutExpired(cmd="test", timeout=15), # SIGINT times out + 0, # SIGTERM succeeds + ] + + self.bridge._shutdown_process(mock_process, process_name) + + expected_calls = [ + unittest.mock.call(signal.SIGINT), + unittest.mock.call(signal.SIGTERM), + ] + self.assertEqual(mock_process.send_signal.call_args_list, expected_calls) + self.assertEqual(mock_process.wait.call_count, 2) + + mock_process.reset_mock() + + # Test case where both SIGINT and SIGTERM time out, requiring SIGKILL + mock_process.wait.side_effect = [ + subprocess.TimeoutExpired(cmd="test", timeout=15), # SIGINT times out + subprocess.TimeoutExpired(cmd="test", timeout=15), # SIGTERM times out + 0, # SIGKILL succeeds + ] + + self.bridge._shutdown_process(mock_process, process_name) + + expected_calls = [ + unittest.mock.call(signal.SIGINT), + unittest.mock.call(signal.SIGTERM), + ] + self.assertEqual(mock_process.send_signal.call_args_list, expected_calls) + self.assertEqual(mock_process.kill.call_count, 1) + self.assertEqual(mock_process.wait.call_count, 3) + + # Test case where process is None + self.bridge._shutdown_process(None, "nonexistent_process") + # No exceptions should be raised + + def test_shutdown_binary(self): + # Setup + mock_process = MagicMock(spec=subprocess.Popen) + self.bridge.current_sim_process = mock_process + + # Mock _shutdown_process + self.bridge._shutdown_process = MagicMock() + + # Call the method + self.bridge._shutdown_binary() + + # Verify _shutdown_process was called with the right parameters + self.bridge._shutdown_process.assert_called_once_with( + process=mock_process, process_name="binary" + ) + + # Verify current_sim_process was set to None self.assertIsNone(self.bridge.current_sim_process) def test_shutdown_robotic_stack(self): - mock_process = MagicMock() - mock_process.poll.return_value = 0 + # Call the method + self.bridge._shutdown_robotic_stack() - self.bridge.current_robotic_stack_process = mock_process + # Verify manager.shutdown was called + self.mock_launch_manager.shutdown.assert_called_once() - self.bridge._shutdown_robotic_stack() + def test_shutdown(self): + # Mock the component shutdown methods + self.bridge._shutdown_binary = MagicMock() + self.bridge._shutdown_robotic_stack = MagicMock() - mock_process.send_signal.assert_called_once_with(signal.SIGINT) - mock_process.wait.assert_called_once() - self.assertIsNone(self.bridge.current_robotic_stack_process) + # Call the method + self.bridge.shutdown() + + # Verify component shutdown methods were called + self.bridge._shutdown_binary.assert_called_once() + self.bridge._shutdown_robotic_stack.assert_called_once() def test_get_available_spawnable_names(self): # Mock the response diff --git a/tests/rai_sim/test_simulation_bridge.py b/tests/rai_sim/test_simulation_bridge.py index 2a05368cf..5e383fb85 100644 --- a/tests/rai_sim/test_simulation_bridge.py +++ b/tests/rai_sim/test_simulation_bridge.py @@ -29,6 +29,7 @@ from rai_sim.simulation_bridge import ( Entity, + SceneConfig, SceneState, SimulationBridge, SimulationConfig, @@ -114,13 +115,13 @@ def test_spawned_entity(pose: PoseStamped): assert spawned_entity.id == "id_123" -def test_simulation_config_unique_names(pose): +def test_scene_config_unique_names(pose: PoseStamped): entities = [ Entity(name="entity1", prefab_name="cube", pose=pose), Entity(name="entity2", prefab_name="carrot", pose=pose), ] - config = SimulationConfig(entities=entities) + config = SceneConfig(entities=entities) assert isinstance(config.entities, list) assert all(isinstance(e, Entity) for e in config.entities) @@ -128,18 +129,18 @@ def test_simulation_config_unique_names(pose): assert len(config.entities) == 2 -def test_simulation_config_duplicate_names(pose): +def test_scene_config_duplicate_names(pose: PoseStamped): entities = [ Entity(name="duplicate", prefab_name="cube", pose=pose), Entity(name="duplicate", prefab_name="carrot", pose=pose), ] with pytest.raises(ValidationError): - SimulationConfig(entities=entities) + SceneConfig(entities=entities) def test_load_base_config(sample_base_yaml_config: Path): - config = SimulationConfig.load_base_config(sample_base_yaml_config) + config = SceneConfig.load_base_config(sample_base_yaml_config) assert isinstance(config.entities, list) assert all(isinstance(e, Entity) for e in config.entities) @@ -150,9 +151,12 @@ def test_load_base_config(sample_base_yaml_config: Path): class MockSimulationBridge(SimulationBridge[SimulationConfig]): """Mock implementation of SimulationBridge for testing.""" - def setup_scene(self, simulation_config: SimulationConfig): + def init_simulation(self, simulation_config: SimulationConfig): + pass + + def setup_scene(self, scene_config: SceneConfig): """Mock implementation of setup_scene.""" - for entity in simulation_config.entities: + for entity in scene_config.entities: self._spawn_entity(entity) def _spawn_entity(self, entity: Entity): @@ -211,9 +215,7 @@ def setUp(self): ) # Create a test configuration - self.test_config = SimulationConfig( - entities=[self.test_entity1, self.test_entity2] - ) + self.test_config = SceneConfig(entities=[self.test_entity1, self.test_entity2]) def test_init(self): # Test with provided logger