Skip to content

Commit cfdf1ee

Browse files
committed
fix: applied suggestions
1 parent f7d36f9 commit cfdf1ee

File tree

4 files changed

+24
-24
lines changed

4 files changed

+24
-24
lines changed

src/rai_bench/rai_bench/tool_calling_agent/benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def run_next(self, agent: CompiledStateGraph, experiment_id: uuid.UUID) -> None:
8787
)
8888
callbacks = self.score_tracing_handler.get_callbacks()
8989
run_id = uuid.uuid4()
90-
# NOTE (jmatejcz) reccustion limit calculated as all_nodes_num -> one pass though whole node
90+
# NOTE (jmatejcz) recursion limit calculated as all_nodes_num -> one pass through whole node
9191
# plus (task.max_tool_calls_number-1 because the first pass is already added in)
9292
# times number of nodes - 2 because we don't count start and end node
9393
# this can be too much for larger graphs that don't use all nodes on extra calls

src/rai_bench/rai_bench/tool_calling_agent/predefined/basic_tasks.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
)
4141
from rai_bench.tool_calling_agent.validators import (
4242
NotOrderedCallsValidator,
43-
OptionalValidator,
43+
OneFromManyValidator,
4444
OrderedCallsValidator,
4545
)
4646

@@ -409,7 +409,7 @@
409409
get_pointcloud_ord_val = OrderedCallsValidator(subtasks=[receive_pointcloud_subtask])
410410
get_robot_desc_ord_val = OrderedCallsValidator(subtasks=[receive_robot_desc_subtask])
411411

412-
set_param_val = OptionalValidator(
412+
set_param_val = OneFromManyValidator(
413413
subtasks=[set_robot_state_params_subtask, set_robot_state_params_atomically_subtask]
414414
)
415415
services_ord_val = OrderedCallsValidator(subtasks=[get_services_subtask])
@@ -420,24 +420,24 @@
420420
)
421421
spawn_entity_val = OrderedCallsValidator(subtasks=[spawn_entity_subtask])
422422

423-
set_grounded_sam_opt_val_1 = OptionalValidator(
423+
set_grounded_sam_opt_val_1 = OneFromManyValidator(
424424
subtasks=[set_grounded_sam_subtask_1, set_grounded_sam_atomically_subtask_1]
425425
)
426-
set_grounded_dino_opt_val_1 = OptionalValidator(
426+
set_grounded_dino_opt_val_1 = OneFromManyValidator(
427427
subtasks=[set_grounded_dino_subtask_1, set_grounding_dino_atomically_subtask_1]
428428
)
429-
set_o3de_fps_opt_val_1 = OptionalValidator(
429+
set_o3de_fps_opt_val_1 = OneFromManyValidator(
430430
subtasks=[set_o3de_fps_subtask_1, set_o3de_fps_atomically_subtask_1]
431431
)
432432

433433

434-
set_grounded_sam_opt_val_2 = OptionalValidator(
434+
set_grounded_sam_opt_val_2 = OneFromManyValidator(
435435
subtasks=[set_grounded_sam_subtask_2, set_grounded_sam_atomically_subtask_2]
436436
)
437-
set_grounded_dino_opt_val_2 = OptionalValidator(
437+
set_grounded_dino_opt_val_2 = OneFromManyValidator(
438438
subtasks=[set_grounding_dino_subtask_2, set_grounding_dino_atomically_subtask_2]
439439
)
440-
set_o3de_fps_opt_val_2 = OptionalValidator(
440+
set_o3de_fps_opt_val_2 = OneFromManyValidator(
441441
subtasks=[set_o3de_fps_subtask_2, set_o3de_fps_atomically_subtask_2]
442442
)
443443

src/rai_bench/rai_bench/tool_calling_agent/validators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def validate(self, tool_calls: List[ToolCall]) -> Tuple[bool, List[ToolCall]]:
146146
return False, []
147147

148148

149-
class OptionalValidator(Validator):
149+
class OneFromManyValidator(Validator):
150150
"""
151151
Validator that passes when any one of the given subtasks passes.
152152
"""

tests/rai_bench/tool_calling_agent/test_validators.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from rai_bench.tool_calling_agent.interfaces import SubTaskValidationError, Validator
2121
from rai_bench.tool_calling_agent.validators import (
2222
NotOrderedCallsValidator,
23-
OptionalValidator,
23+
OneFromManyValidator,
2424
OrderedCallsValidator,
2525
)
2626

@@ -680,11 +680,11 @@ def test_validate_reset(self):
680680
class TestOptionalValidator:
681681
def test_init_with_empty_subtasks(self):
682682
with pytest.raises(ValueError, match="Validator must have at least 1 subtask"):
683-
OptionalValidator(subtasks=[])
683+
OneFromManyValidator(subtasks=[])
684684

685685
def test_validate_empty_tool_calls(self):
686686
subtasks = [DummySubTask("task1")]
687-
validator = OptionalValidator(subtasks=subtasks)
687+
validator = OneFromManyValidator(subtasks=subtasks)
688688

689689
success, remaining = validator.validate(tool_calls=[])
690690

@@ -709,7 +709,7 @@ def test_validate_successful_first_subtask_matches(self):
709709
DummySubTask("task1", specific_tool="tool1"),
710710
DummySubTask("task2", specific_tool="tool2"),
711711
]
712-
validator = OptionalValidator(subtasks=subtasks)
712+
validator = OneFromManyValidator(subtasks=subtasks)
713713
tool_calls = [ToolCall(name="tool1")]
714714

715715
success, remaining = validator.validate(tool_calls=tool_calls)
@@ -737,7 +737,7 @@ def test_validate_successful_second_subtask_matches(self):
737737
DummySubTask("task1", specific_tool="tool1"),
738738
DummySubTask("task2", specific_tool="tool2"),
739739
]
740-
validator = OptionalValidator(subtasks=subtasks)
740+
validator = OneFromManyValidator(subtasks=subtasks)
741741
tool_calls = [ToolCall(name="tool2")]
742742

743743
success, remaining = validator.validate(tool_calls=tool_calls)
@@ -765,7 +765,7 @@ def test_validate_successful_with_excess_tool_calls(self):
765765
DummySubTask("task1", specific_tool="tool1"),
766766
DummySubTask("task2", specific_tool="tool2"),
767767
]
768-
validator = OptionalValidator(subtasks=subtasks)
768+
validator = OneFromManyValidator(subtasks=subtasks)
769769
tool_calls = [
770770
ToolCall(name="tool1"),
771771
ToolCall(name="extra_tool"),
@@ -799,7 +799,7 @@ def test_validate_successful_after_failed_attempts(self):
799799
DummySubTask("task1", specific_tool="tool1"),
800800
DummySubTask("task2", specific_tool="tool2"),
801801
]
802-
validator = OptionalValidator(subtasks=subtasks)
802+
validator = OneFromManyValidator(subtasks=subtasks)
803803
tool_calls = [
804804
ToolCall(name="wrong_tool"),
805805
ToolCall(name="another_wrong"),
@@ -835,7 +835,7 @@ def test_validate_failure_no_subtask_matches(self):
835835
DummySubTask("task1", specific_tool="tool1"),
836836
DummySubTask("task2", specific_tool="tool2"),
837837
]
838-
validator = OptionalValidator(subtasks=subtasks)
838+
validator = OneFromManyValidator(subtasks=subtasks)
839839
tool_calls = [
840840
ToolCall(name="wrong_tool"),
841841
ToolCall(name="another_wrong"),
@@ -868,7 +868,7 @@ def test_validate_failure_subtask_validation_error(self):
868868
DummySubTask("task1", outcomes=[False]),
869869
DummySubTask("task2", outcomes=[False]),
870870
]
871-
validator = OptionalValidator(subtasks=subtasks)
871+
validator = OneFromManyValidator(subtasks=subtasks)
872872
tool_calls = [ToolCall()]
873873

874874
success, remaining = validator.validate(tool_calls=tool_calls)
@@ -895,7 +895,7 @@ def test_validate_failure_subtask_validation_error(self):
895895

896896
def test_validate_single_subtask_success(self):
897897
subtasks = [DummySubTask("task1")]
898-
validator = OptionalValidator(subtasks=subtasks)
898+
validator = OneFromManyValidator(subtasks=subtasks)
899899
tool_calls = [ToolCall()]
900900

901901
success, remaining = validator.validate(tool_calls=tool_calls)
@@ -918,7 +918,7 @@ def test_validate_single_subtask_success(self):
918918

919919
def test_validate_single_subtask_failure(self):
920920
subtasks = [DummySubTask("task1", outcomes=[False])]
921-
validator = OptionalValidator(subtasks=subtasks)
921+
validator = OneFromManyValidator(subtasks=subtasks)
922922
tool_calls = [ToolCall()]
923923

924924
success, remaining = validator.validate(tool_calls=tool_calls)
@@ -947,7 +947,7 @@ def test_validate_many_subtasks_last_one_succeeds(self):
947947
DummySubTask("task3", specific_tool="tool3"),
948948
DummySubTask("task4", specific_tool="tool4"),
949949
]
950-
validator = OptionalValidator(subtasks=subtasks)
950+
validator = OneFromManyValidator(subtasks=subtasks)
951951
tool_calls = [ToolCall(name="tool4")]
952952

953953
success, remaining = validator.validate(tool_calls=tool_calls)
@@ -979,7 +979,7 @@ def test_validate_reset(self):
979979
DummySubTask("task1", outcomes=4 * [False]),
980980
DummySubTask("task2", outcomes=4 * [False]),
981981
]
982-
validator = OptionalValidator(subtasks=subtasks)
982+
validator = OneFromManyValidator(subtasks=subtasks)
983983
tool_calls = [ToolCall(), ToolCall()]
984984

985985
# First validation call
@@ -1013,7 +1013,7 @@ def test_required_calls_property(self):
10131013
DummySubTask("task2"),
10141014
DummySubTask("task3"),
10151015
]
1016-
validator = OptionalValidator(subtasks=subtasks)
1016+
validator = OneFromManyValidator(subtasks=subtasks)
10171017

10181018
# OneFromManyValidator should only require 1 call
10191019
assert validator.required_calls == 1

0 commit comments

Comments
 (0)