20
20
from rai_bench .tool_calling_agent .interfaces import SubTaskValidationError , Validator
21
21
from rai_bench .tool_calling_agent .validators import (
22
22
NotOrderedCallsValidator ,
23
- OptionalValidator ,
23
+ OneFromManyValidator ,
24
24
OrderedCallsValidator ,
25
25
)
26
26
@@ -680,11 +680,11 @@ def test_validate_reset(self):
680
680
class TestOptionalValidator :
681
681
def test_init_with_empty_subtasks (self ):
682
682
with pytest .raises (ValueError , match = "Validator must have at least 1 subtask" ):
683
- OptionalValidator (subtasks = [])
683
+ OneFromManyValidator (subtasks = [])
684
684
685
685
def test_validate_empty_tool_calls (self ):
686
686
subtasks = [DummySubTask ("task1" )]
687
- validator = OptionalValidator (subtasks = subtasks )
687
+ validator = OneFromManyValidator (subtasks = subtasks )
688
688
689
689
success , remaining = validator .validate (tool_calls = [])
690
690
@@ -709,7 +709,7 @@ def test_validate_successful_first_subtask_matches(self):
709
709
DummySubTask ("task1" , specific_tool = "tool1" ),
710
710
DummySubTask ("task2" , specific_tool = "tool2" ),
711
711
]
712
- validator = OptionalValidator (subtasks = subtasks )
712
+ validator = OneFromManyValidator (subtasks = subtasks )
713
713
tool_calls = [ToolCall (name = "tool1" )]
714
714
715
715
success , remaining = validator .validate (tool_calls = tool_calls )
@@ -737,7 +737,7 @@ def test_validate_successful_second_subtask_matches(self):
737
737
DummySubTask ("task1" , specific_tool = "tool1" ),
738
738
DummySubTask ("task2" , specific_tool = "tool2" ),
739
739
]
740
- validator = OptionalValidator (subtasks = subtasks )
740
+ validator = OneFromManyValidator (subtasks = subtasks )
741
741
tool_calls = [ToolCall (name = "tool2" )]
742
742
743
743
success , remaining = validator .validate (tool_calls = tool_calls )
@@ -765,7 +765,7 @@ def test_validate_successful_with_excess_tool_calls(self):
765
765
DummySubTask ("task1" , specific_tool = "tool1" ),
766
766
DummySubTask ("task2" , specific_tool = "tool2" ),
767
767
]
768
- validator = OptionalValidator (subtasks = subtasks )
768
+ validator = OneFromManyValidator (subtasks = subtasks )
769
769
tool_calls = [
770
770
ToolCall (name = "tool1" ),
771
771
ToolCall (name = "extra_tool" ),
@@ -799,7 +799,7 @@ def test_validate_successful_after_failed_attempts(self):
799
799
DummySubTask ("task1" , specific_tool = "tool1" ),
800
800
DummySubTask ("task2" , specific_tool = "tool2" ),
801
801
]
802
- validator = OptionalValidator (subtasks = subtasks )
802
+ validator = OneFromManyValidator (subtasks = subtasks )
803
803
tool_calls = [
804
804
ToolCall (name = "wrong_tool" ),
805
805
ToolCall (name = "another_wrong" ),
@@ -835,7 +835,7 @@ def test_validate_failure_no_subtask_matches(self):
835
835
DummySubTask ("task1" , specific_tool = "tool1" ),
836
836
DummySubTask ("task2" , specific_tool = "tool2" ),
837
837
]
838
- validator = OptionalValidator (subtasks = subtasks )
838
+ validator = OneFromManyValidator (subtasks = subtasks )
839
839
tool_calls = [
840
840
ToolCall (name = "wrong_tool" ),
841
841
ToolCall (name = "another_wrong" ),
@@ -868,7 +868,7 @@ def test_validate_failure_subtask_validation_error(self):
868
868
DummySubTask ("task1" , outcomes = [False ]),
869
869
DummySubTask ("task2" , outcomes = [False ]),
870
870
]
871
- validator = OptionalValidator (subtasks = subtasks )
871
+ validator = OneFromManyValidator (subtasks = subtasks )
872
872
tool_calls = [ToolCall ()]
873
873
874
874
success , remaining = validator .validate (tool_calls = tool_calls )
@@ -895,7 +895,7 @@ def test_validate_failure_subtask_validation_error(self):
895
895
896
896
def test_validate_single_subtask_success (self ):
897
897
subtasks = [DummySubTask ("task1" )]
898
- validator = OptionalValidator (subtasks = subtasks )
898
+ validator = OneFromManyValidator (subtasks = subtasks )
899
899
tool_calls = [ToolCall ()]
900
900
901
901
success , remaining = validator .validate (tool_calls = tool_calls )
@@ -918,7 +918,7 @@ def test_validate_single_subtask_success(self):
918
918
919
919
def test_validate_single_subtask_failure (self ):
920
920
subtasks = [DummySubTask ("task1" , outcomes = [False ])]
921
- validator = OptionalValidator (subtasks = subtasks )
921
+ validator = OneFromManyValidator (subtasks = subtasks )
922
922
tool_calls = [ToolCall ()]
923
923
924
924
success , remaining = validator .validate (tool_calls = tool_calls )
@@ -947,7 +947,7 @@ def test_validate_many_subtasks_last_one_succeeds(self):
947
947
DummySubTask ("task3" , specific_tool = "tool3" ),
948
948
DummySubTask ("task4" , specific_tool = "tool4" ),
949
949
]
950
- validator = OptionalValidator (subtasks = subtasks )
950
+ validator = OneFromManyValidator (subtasks = subtasks )
951
951
tool_calls = [ToolCall (name = "tool4" )]
952
952
953
953
success , remaining = validator .validate (tool_calls = tool_calls )
@@ -979,7 +979,7 @@ def test_validate_reset(self):
979
979
DummySubTask ("task1" , outcomes = 4 * [False ]),
980
980
DummySubTask ("task2" , outcomes = 4 * [False ]),
981
981
]
982
- validator = OptionalValidator (subtasks = subtasks )
982
+ validator = OneFromManyValidator (subtasks = subtasks )
983
983
tool_calls = [ToolCall (), ToolCall ()]
984
984
985
985
# First validation call
@@ -1013,7 +1013,7 @@ def test_required_calls_property(self):
1013
1013
DummySubTask ("task2" ),
1014
1014
DummySubTask ("task3" ),
1015
1015
]
1016
- validator = OptionalValidator (subtasks = subtasks )
1016
+ validator = OneFromManyValidator (subtasks = subtasks )
1017
1017
1018
1018
# OneFromManyValidator should only require 1 call
1019
1019
assert validator .required_calls == 1
0 commit comments