|
3 | 3 | from unittest.mock import MagicMock, Mock
|
4 | 4 |
|
5 | 5 | import pytest
|
| 6 | +from litellm import ChatCompletionMessageToolCall |
6 | 7 |
|
7 | 8 | from openhands.controller.state.state import State
|
8 | 9 | from openhands.core.config.agent_config import AgentConfig
|
@@ -1050,3 +1051,150 @@ def test_has_agent_in_earlier_events(conversation_memory):
|
1050 | 1051 | conversation_memory._has_agent_in_earlier_events('non_existent', 3, events)
|
1051 | 1052 | is False
|
1052 | 1053 | )
|
| 1054 | + |
| 1055 | + |
| 1056 | +class TestFilterUnmatchedToolCalls: |
| 1057 | + @pytest.fixture |
| 1058 | + def processor(self): |
| 1059 | + return ConversationMemory() |
| 1060 | + |
| 1061 | + def test_empty_is_unchanged(self): |
| 1062 | + assert list(ConversationMemory._filter_unmatched_tool_calls([])) == [] |
| 1063 | + |
| 1064 | + def test_no_tool_calls_is_unchanged(self): |
| 1065 | + messages = [ |
| 1066 | + Message(role='user', content=[TextContent(text='Hello')]), |
| 1067 | + Message(role='assistant', content=[TextContent(text='Hi there')]), |
| 1068 | + Message(role='user', content=[TextContent(text='How are you?')]), |
| 1069 | + ] |
| 1070 | + assert ( |
| 1071 | + list(ConversationMemory._filter_unmatched_tool_calls(messages)) == messages |
| 1072 | + ) |
| 1073 | + |
| 1074 | + def test_matched_tool_calls_are_unchanged(self): |
| 1075 | + messages = [ |
| 1076 | + Message(role='user', content=[TextContent(text="What's the weather?")]), |
| 1077 | + Message( |
| 1078 | + role='assistant', |
| 1079 | + content=[], |
| 1080 | + tool_calls=[ |
| 1081 | + ChatCompletionMessageToolCall( |
| 1082 | + id='call_1', |
| 1083 | + type='function', |
| 1084 | + function={'name': 'get_weather', 'arguments': ''}, |
| 1085 | + ) |
| 1086 | + ], |
| 1087 | + ), |
| 1088 | + Message( |
| 1089 | + role='tool', |
| 1090 | + tool_call_id='call_1', |
| 1091 | + content=[TextContent(text='Sunny, 75°F')], |
| 1092 | + ), |
| 1093 | + Message(role='assistant', content=[TextContent(text="It's sunny today.")]), |
| 1094 | + ] |
| 1095 | + |
| 1096 | + # All tool calls have matching responses, should remain unchanged |
| 1097 | + assert ( |
| 1098 | + list(ConversationMemory._filter_unmatched_tool_calls(messages)) == messages |
| 1099 | + ) |
| 1100 | + |
| 1101 | + def test_tool_call_without_response_is_removed(self): |
| 1102 | + messages = [ |
| 1103 | + Message(role='user', content=[TextContent(text='Query')]), |
| 1104 | + Message( |
| 1105 | + role='tool', |
| 1106 | + tool_call_id='missing_call', |
| 1107 | + content=[TextContent(text='Response')], |
| 1108 | + ), |
| 1109 | + Message(role='assistant', content=[TextContent(text='Answer')]), |
| 1110 | + ] |
| 1111 | + |
| 1112 | + expected_after_filter = [ |
| 1113 | + Message(role='user', content=[TextContent(text='Query')]), |
| 1114 | + Message(role='assistant', content=[TextContent(text='Answer')]), |
| 1115 | + ] |
| 1116 | + |
| 1117 | + result = list(ConversationMemory._filter_unmatched_tool_calls(messages)) |
| 1118 | + assert result == expected_after_filter |
| 1119 | + |
| 1120 | + def test_tool_response_without_call_is_removed(self): |
| 1121 | + messages = [ |
| 1122 | + Message(role='user', content=[TextContent(text='Query')]), |
| 1123 | + Message( |
| 1124 | + role='assistant', |
| 1125 | + content=[], |
| 1126 | + tool_calls=[ |
| 1127 | + ChatCompletionMessageToolCall( |
| 1128 | + id='unmatched_call', |
| 1129 | + type='function', |
| 1130 | + function={'name': 'some_function', 'arguments': ''}, |
| 1131 | + ) |
| 1132 | + ], |
| 1133 | + ), |
| 1134 | + Message(role='assistant', content=[TextContent(text='Answer')]), |
| 1135 | + ] |
| 1136 | + |
| 1137 | + expected_after_filter = [ |
| 1138 | + Message(role='user', content=[TextContent(text='Query')]), |
| 1139 | + Message(role='assistant', content=[TextContent(text='Answer')]), |
| 1140 | + ] |
| 1141 | + |
| 1142 | + result = list(ConversationMemory._filter_unmatched_tool_calls(messages)) |
| 1143 | + assert result == expected_after_filter |
| 1144 | + |
| 1145 | + def test_partial_matched_tool_calls_retains_matched(self): |
| 1146 | + """When there are both matched and unmatched tools calls in a message, retain the message and only matched calls""" |
| 1147 | + messages = [ |
| 1148 | + Message(role='user', content=[TextContent(text='Get data')]), |
| 1149 | + Message( |
| 1150 | + role='assistant', |
| 1151 | + content=[], |
| 1152 | + tool_calls=[ |
| 1153 | + ChatCompletionMessageToolCall( |
| 1154 | + id='matched_call', |
| 1155 | + type='function', |
| 1156 | + function={'name': 'function1', 'arguments': ''}, |
| 1157 | + ), |
| 1158 | + ChatCompletionMessageToolCall( |
| 1159 | + id='unmatched_call', |
| 1160 | + type='function', |
| 1161 | + function={'name': 'function2', 'arguments': ''}, |
| 1162 | + ), |
| 1163 | + ], |
| 1164 | + ), |
| 1165 | + Message( |
| 1166 | + role='tool', |
| 1167 | + tool_call_id='matched_call', |
| 1168 | + content=[TextContent(text='Data')], |
| 1169 | + ), |
| 1170 | + Message(role='assistant', content=[TextContent(text='Result')]), |
| 1171 | + ] |
| 1172 | + |
| 1173 | + expected = [ |
| 1174 | + Message(role='user', content=[TextContent(text='Get data')]), |
| 1175 | + # This message should be modified to only include the matched tool call |
| 1176 | + Message( |
| 1177 | + role='assistant', |
| 1178 | + content=[], |
| 1179 | + tool_calls=[ |
| 1180 | + ChatCompletionMessageToolCall( |
| 1181 | + id='matched_call', |
| 1182 | + type='function', |
| 1183 | + function={'name': 'function1', 'arguments': ''}, |
| 1184 | + ) |
| 1185 | + ], |
| 1186 | + ), |
| 1187 | + Message( |
| 1188 | + role='tool', |
| 1189 | + tool_call_id='matched_call', |
| 1190 | + content=[TextContent(text='Data')], |
| 1191 | + ), |
| 1192 | + Message(role='assistant', content=[TextContent(text='Result')]), |
| 1193 | + ] |
| 1194 | + |
| 1195 | + result = list(ConversationMemory._filter_unmatched_tool_calls(messages)) |
| 1196 | + |
| 1197 | + # Verify result structure |
| 1198 | + assert len(result) == len(expected) |
| 1199 | + for i, msg in enumerate(result): |
| 1200 | + assert msg == expected[i] |
0 commit comments