Skip to content

Commit 5b47a8c

Browse files
committed
fix: update the remaining examples and modify default args
1 parent d7fc436 commit 5b47a8c

File tree

5 files changed

+27
-19
lines changed

5 files changed

+27
-19
lines changed

src/rai_bench/rai_bench/examples/manipulation_o3de.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from rai_bench import define_benchmark_logger, parse_manipulation_o3de_benchmark_args
1818
from rai_bench.manipulation_o3de import get_scenarios, run_benchmark
19+
from rai_bench.utils import get_llm_for_benchmark
1920

2021
if __name__ == "__main__":
2122
args = parse_manipulation_o3de_benchmark_args()
@@ -26,9 +27,13 @@
2627
# import ready scenarios
2728
scenarios = get_scenarios(logger=bench_logger, levels=args.levels)
2829

29-
run_benchmark(
30+
llm = get_llm_for_benchmark(
3031
model_name=args.model_name,
3132
vendor=args.vendor,
33+
)
34+
35+
run_benchmark(
36+
llm=llm,
3237
out_dir=experiment_dir,
3338
o3de_config_path=args.o3de_config_path,
3439
scenarios=scenarios,

src/rai_bench/rai_bench/examples/tool_calling_agent.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
get_tasks,
2323
run_benchmark,
2424
)
25+
from rai_bench.utils import get_llm_for_benchmark
2526

2627
if __name__ == "__main__":
2728
args = parse_tool_calling_benchmark_args()
@@ -36,9 +37,14 @@
3637
)
3738
for task in tasks:
3839
task.set_logger(bench_logger)
39-
run_benchmark(
40+
41+
llm = get_llm_for_benchmark(
4042
model_name=args.model_name,
4143
vendor=args.vendor,
44+
)
45+
46+
run_benchmark(
47+
llm=llm,
4248
out_dir=args.out_dir,
4349
tasks=tasks,
4450
bench_logger=bench_logger,

src/rai_bench/rai_bench/manipulation_o3de/benchmark.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
ScenarioResult,
5353
)
5454
from rai_bench.results_processing.langfuse_scores_tracing import ScoreTracingHandler
55+
from rai_bench.utils import get_llm_model_name
5556
from rai_sim.o3de.o3de_bridge import (
5657
O3DEngineArmManipulationBridge,
5758
O3DExROS2SimulationConfig,
@@ -422,15 +423,14 @@ def _setup_benchmark_environment(
422423

423424
def run_benchmark(
424425
llm: BaseChatModel,
425-
model_name: str,
426426
out_dir: Path,
427427
o3de_config_path: str,
428428
scenarios: List[Scenario],
429-
experiment_id: uuid.UUID,
430429
bench_logger: logging.Logger,
430+
experiment_id: uuid.UUID = uuid.uuid4(),
431431
):
432432
connector, o3de, benchmark, tools = _setup_benchmark_environment(
433-
o3de_config_path, model_name, scenarios, out_dir, bench_logger
433+
o3de_config_path, get_llm_model_name(llm), scenarios, out_dir, bench_logger
434434
)
435435
try:
436436
for scenario in scenarios:
@@ -459,17 +459,20 @@ def run_benchmark(
459459
def run_benchmark_dual_agent(
460460
multimodal_llm: BaseChatModel,
461461
tool_calling_llm: BaseChatModel,
462-
model_name: str,
463462
out_dir: Path,
464463
scenarios: List[Scenario],
465464
o3de_config_path: str,
466-
experiment_id: uuid.UUID,
467465
bench_logger: logging.Logger,
466+
experiment_id: uuid.UUID = uuid.uuid4(),
468467
m_system_prompt: Optional[str] = None,
469468
tool_system_prompt: Optional[str] = None,
470469
):
471470
connector, o3de, benchmark, tools = _setup_benchmark_environment(
472-
o3de_config_path, model_name, scenarios, out_dir, bench_logger
471+
o3de_config_path,
472+
get_llm_model_name(multimodal_llm),
473+
scenarios,
474+
out_dir,
475+
bench_logger,
473476
)
474477
basic_tool_system_prompt = (
475478
"Based on the conversation call the tools with appropriate arguments"
@@ -489,7 +492,6 @@ def run_benchmark_dual_agent(
489492
else basic_tool_system_prompt
490493
),
491494
logger=bench_logger,
492-
debug=True,
493495
)
494496

495497
benchmark.run_next(agent=agent, experiment_id=experiment_id)

src/rai_bench/rai_bench/test_models.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ def test_dual_agents(
123123
tool_calling_llm=tool_llm,
124124
m_system_prompt=m_system_prompt,
125125
tool_system_prompt=tool_system_prompt,
126-
model_name=get_llm_model_name(m_llm),
127126
out_dir=Path(curr_out_dir),
128127
tasks=tool_calling_tasks,
129128
experiment_id=experiment_id,
@@ -137,7 +136,6 @@ def test_dual_agents(
137136
manipulation_o3de.run_benchmark_dual_agent(
138137
multimodal_llm=m_llm,
139138
tool_calling_llm=tool_llm,
140-
model_name=m_llm.get_name(),
141139
out_dir=Path(curr_out_dir),
142140
o3de_config_path=bench_conf.o3de_config_path,
143141
scenarios=manipulation_o3de_scenarios,
@@ -195,7 +193,6 @@ def test_models(
195193
)
196194
tool_calling_agent.run_benchmark(
197195
llm=llm,
198-
model_name=model_name,
199196
out_dir=Path(curr_out_dir),
200197
tasks=tool_calling_tasks,
201198
experiment_id=experiment_id,
@@ -210,7 +207,6 @@ def test_models(
210207
)
211208
manipulation_o3de.run_benchmark(
212209
llm=llm,
213-
model_name=model_name,
214210
out_dir=Path(curr_out_dir),
215211
o3de_config_path=bench_conf.o3de_config_path,
216212
scenarios=manipulation_o3de_scenarios,

src/rai_bench/rai_bench/tool_calling_agent/benchmark.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from rai_bench.tool_calling_agent.tasks.spatial import (
4242
SpatialReasoningAgentTask,
4343
)
44+
from rai_bench.utils import get_llm_model_name
4445

4546

4647
class ToolCallingAgentBenchmark(BaseBenchmark):
@@ -207,16 +208,15 @@ def compute_and_save_summary(self):
207208

208209
def run_benchmark(
209210
llm: BaseChatModel,
210-
model_name: str,
211211
out_dir: Path,
212212
tasks: List[Task],
213-
experiment_id: uuid.UUID,
214213
bench_logger: logging.Logger,
214+
experiment_id: uuid.UUID = uuid.uuid4(),
215215
):
216216
benchmark = ToolCallingAgentBenchmark(
217217
tasks=tasks,
218218
logger=bench_logger,
219-
model_name=model_name,
219+
model_name=get_llm_model_name(llm),
220220
results_dir=out_dir,
221221
)
222222

@@ -237,18 +237,17 @@ def run_benchmark(
237237
def run_benchmark_dual_agent(
238238
multimodal_llm: BaseChatModel,
239239
tool_calling_llm: BaseChatModel,
240-
model_name: str,
241240
out_dir: Path,
242241
tasks: List[Task],
243-
experiment_id: uuid.UUID,
244242
bench_logger: logging.Logger,
243+
experiment_id: uuid.UUID = uuid.uuid4(),
245244
m_system_prompt: Optional[str] = None,
246245
tool_system_prompt: Optional[str] = None,
247246
):
248247
benchmark = ToolCallingAgentBenchmark(
249248
tasks=tasks,
250249
logger=bench_logger,
251-
model_name=model_name,
250+
model_name=get_llm_model_name(multimodal_llm),
252251
results_dir=out_dir,
253252
)
254253

0 commit comments

Comments
 (0)