Skip to content

feat: add missing logic rai_bench #595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/rai_bench/rai_bench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .test_models import (
ManipulationO3DEBenchmarkConfig,
ToolCallingAgentBenchmarkConfig,
test_models,
)
from .utils import (
define_benchmark_logger,
get_llm_for_benchmark,
parse_manipulation_o3de_benchmark_args,
parse_tool_calling_benchmark_args,
)

__all__ = [
"ManipulationO3DEBenchmarkConfig",
"ToolCallingAgentBenchmarkConfig",
"define_benchmark_logger",
"get_llm_for_benchmark",
"parse_manipulation_o3de_benchmark_args",
"parse_tool_calling_benchmark_args",
"test_models",
]
35 changes: 26 additions & 9 deletions src/rai_bench/rai_bench/base_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,30 @@

import csv
import logging
import signal
import types
from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Optional

from langgraph.graph.state import CompiledStateGraph
from pydantic import BaseModel, Field


class BenchmarkSummary(BaseModel):
class RunSummary(BaseModel):
model_name: str = Field(..., description="Name of the LLM.")
success_rate: float = Field(
..., description="Percentage of successfully completed tasks."
)
avg_time: float = Field(..., description="Average time taken across all tasks.")
total_extra_tool_calls_used: int = Field(
..., description="Total number of extra tool calls used in this Task"
)
total_tasks: int = Field(..., description="Total number of executed tasks.")


class TimeoutException(Exception):
pass


class BaseBenchmark(ABC):
"""Base class for all benchmarks."""

Expand Down Expand Up @@ -76,9 +81,7 @@ def csv_initialize(filename: Path, base_model_cls: type[BaseModel]) -> None:
Pydantic model class to be used for creating the columns in the CSV file.
"""
with open(filename, mode="w", newline="", encoding="utf-8") as file:
writer = csv.DictWriter(
file, fieldnames=base_model_cls.__annotations__.keys()
)
writer = csv.DictWriter(file, fieldnames=base_model_cls.model_fields.keys())
writer.writeheader()

@staticmethod
Expand All @@ -102,7 +105,7 @@ def csv_writerow(filename: Path, base_model_instance: BaseModel) -> None:

with open(filename, mode="a", newline="", encoding="utf-8") as file:
writer = csv.DictWriter(
file, fieldnames=base_model_instance.__annotations__.keys()
file, fieldnames=base_model_instance.model_fields.keys()
)
writer.writerow(row)

Expand All @@ -119,7 +122,21 @@ def run_next(self, agent: CompiledStateGraph) -> None:

@abstractmethod
def compute_and_save_summary(self) -> None:
# TODO (jmatejcz) this can be probably same for all benchmark in the future
"""Compute summary statistics and save them to the summary file."""
pass

# TODO (jm) this can be probably same for all benchmark in the future
@contextmanager
def time_limit(self, seconds: int):
def signal_handler(signum: int, frame: Optional[types.FrameType]):
raise TimeoutException(f"Timed out after {seconds} seconds!")

# Set the timeout handler
signal.signal(signal.SIGALRM, signal_handler)
signal.alarm(seconds)

try:
yield
finally:
# Reset the alarm
signal.alarm(0)
50 changes: 50 additions & 0 deletions src/rai_bench/rai_bench/examples/benchmarking_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from rai_bench import (
ManipulationO3DEBenchmarkConfig,
ToolCallingAgentBenchmarkConfig,
test_models,
)

if __name__ == "__main__":
# Define models you want to benchmark
model_names = ["qwen2.5:7b", "llama3.2:3b"]
vendors = ["ollama", "ollama"]

# Define benchmarks that will be used
man_conf = ManipulationO3DEBenchmarkConfig(
o3de_config_path="src/rai_bench/rai_bench/manipulation_o3de/predefined/configs/o3de_config.yaml", # path to your o3de config
levels=[ # define what difficulty of tasks to include in benchmark
"trivial",
],
repeats=1, # how many times to repeat
)
tool_conf = ToolCallingAgentBenchmarkConfig(
extra_tool_calls=5, # how many extra tool calls allowed to still pass
task_types=[ # what types of tasks to include
"basic",
"spatial_reasoning",
"manipulation",
],
repeats=1,
)

out_dir = "src/rai_bench/rai_bench/experiments"
test_models(
model_names=model_names,
vendors=vendors,
benchmark_configs=[man_conf, tool_conf],
out_dir=out_dir,
)
36 changes: 36 additions & 0 deletions src/rai_bench/rai_bench/examples/manipulation_o3de.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

from rai_bench import define_benchmark_logger, parse_manipulation_o3de_benchmark_args
from rai_bench.manipulation_o3de import get_scenarios, run_benchmark

if __name__ == "__main__":
args = parse_manipulation_o3de_benchmark_args()
experiment_dir = Path(args.out_dir)
experiment_dir.mkdir(parents=True, exist_ok=True)
bench_logger = define_benchmark_logger(out_dir=experiment_dir)

# import ready scenarios
scenarios = get_scenarios(logger=bench_logger, levels=args.levels)

run_benchmark(
model_name=args.model_name,
vendor=args.vendor,
out_dir=experiment_dir,
o3de_config_path=args.o3de_config_path,
scenarios=scenarios,
bench_logger=bench_logger,
)
173 changes: 0 additions & 173 deletions src/rai_bench/rai_bench/examples/manipulation_o3de/main.py

This file was deleted.

Loading
Loading