Skip to content

Commit 7f35055

Browse files
authored
fix - Remove unmatched tool calls for Claude (All-Hands-AI#7597)
1 parent 933ce47 commit 7f35055

File tree

2 files changed

+206
-1
lines changed

2 files changed

+206
-1
lines changed

openhands/memory/conversation_memory.py

+58-1
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
from typing import Generator
2+
13
from litellm import ModelResponse
24

35
from openhands.core.config.agent_config import AgentConfig
@@ -125,7 +127,7 @@ def process_events(
125127
pending_tool_call_action_messages.pop(response_id)
126128

127129
messages += messages_to_add
128-
130+
messages = list(ConversationMemory._filter_unmatched_tool_calls(messages))
129131
return messages
130132

131133
def process_initial_messages(self, with_caching: bool = False) -> list[Message]:
@@ -592,3 +594,58 @@ def _has_agent_in_earlier_events(
592594
):
593595
return True
594596
return False
597+
598+
@staticmethod
def _filter_unmatched_tool_calls(
    messages: list[Message],
) -> Generator[Message, None, None]:
    """Yield ``messages`` with unpaired tool calls and tool responses removed.

    Every ``tool_call_id`` on a yielded ``tool`` message has a matching
    ``tool_calls[].id`` on some ``assistant`` message in ``messages``, and
    every id-bearing assistant tool call that is yielded has a matching
    ``tool`` message.

    Items without an id (``None`` or empty) cannot be paired and are never
    removed: ``tool`` messages with no ``tool_call_id`` and assistant tool
    calls with no ``id`` are passed through untouched.

    The input list and its messages are not modified. When only some of an
    assistant message's tool calls survive, a copy made with ``model_copy``
    carrying the reduced ``tool_calls`` list is yielded instead of the
    original message.

    Args:
        messages: The conversation messages, in order.

    Yields:
        The surviving messages, in their original relative order.
    """
    # Ids of every tool call issued by an assistant message.
    tool_call_ids = {
        tool_call.id
        for message in messages
        if message.role == 'assistant' and message.tool_calls
        for tool_call in message.tool_calls
        if tool_call.id
    }
    # Ids of every tool call that received a tool response.
    tool_response_ids = {
        message.tool_call_id
        for message in messages
        if message.role == 'tool' and message.tool_call_id
    }

    for message in messages:
        if message.role == 'tool' and message.tool_call_id:
            # Drop tool responses that answer no known assistant tool call.
            if message.tool_call_id in tool_call_ids:
                yield message

        elif message.role == 'assistant' and message.tool_calls:
            # Keep answered calls, plus calls that carry no id at all: the
            # latter can never be matched, and the contract is to leave
            # id-less items alone rather than silently discard them.
            kept_tool_calls = [
                tool_call
                for tool_call in message.tool_calls
                if not tool_call.id or tool_call.id in tool_response_ids
            ]

            if len(kept_tool_calls) == len(message.tool_calls):
                # Nothing was filtered out: yield the original object.
                yield message
            elif kept_tool_calls:
                # Keep an updated copy if there are tool calls left.
                yield message.model_copy(update={'tool_calls': kept_tool_calls})
            # Otherwise every tool call was unanswered: drop the message.

        else:
            # Any other case is kept.
            yield message

tests/unit/test_conversation_memory.py

+148
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from unittest.mock import MagicMock, Mock
44

55
import pytest
6+
from litellm import ChatCompletionMessageToolCall
67

78
from openhands.controller.state.state import State
89
from openhands.core.config.agent_config import AgentConfig
@@ -1050,3 +1051,150 @@ def test_has_agent_in_earlier_events(conversation_memory):
10501051
conversation_memory._has_agent_in_earlier_events('non_existent', 3, events)
10511052
is False
10521053
)
1054+
1055+
1056+
class TestFilterUnmatchedToolCalls:
    """Unit tests for ``ConversationMemory._filter_unmatched_tool_calls``."""

    @staticmethod
    def _tool_call(
        call_id: str, function_name: str
    ) -> ChatCompletionMessageToolCall:
        """Build a function tool call with empty arguments."""
        return ChatCompletionMessageToolCall(
            id=call_id,
            type='function',
            function={'name': function_name, 'arguments': ''},
        )

    @staticmethod
    def _filter(messages: list[Message]) -> list[Message]:
        """Run the filter and materialize the generator into a list."""
        return list(ConversationMemory._filter_unmatched_tool_calls(messages))

    def test_empty_is_unchanged(self):
        assert self._filter([]) == []

    def test_no_tool_calls_is_unchanged(self):
        messages = [
            Message(role='user', content=[TextContent(text='Hello')]),
            Message(role='assistant', content=[TextContent(text='Hi there')]),
            Message(role='user', content=[TextContent(text='How are you?')]),
        ]
        assert self._filter(messages) == messages

    def test_matched_tool_calls_are_unchanged(self):
        messages = [
            Message(role='user', content=[TextContent(text="What's the weather?")]),
            Message(
                role='assistant',
                content=[],
                tool_calls=[self._tool_call('call_1', 'get_weather')],
            ),
            Message(
                role='tool',
                tool_call_id='call_1',
                content=[TextContent(text='Sunny, 75°F')],
            ),
            Message(role='assistant', content=[TextContent(text="It's sunny today.")]),
        ]

        # All tool calls have matching responses, should remain unchanged
        assert self._filter(messages) == messages

    def test_tool_call_without_response_is_removed(self):
        """An assistant tool call that never got a tool response is dropped."""
        messages = [
            Message(role='user', content=[TextContent(text='Query')]),
            Message(
                role='assistant',
                content=[],
                tool_calls=[self._tool_call('unmatched_call', 'some_function')],
            ),
            Message(role='assistant', content=[TextContent(text='Answer')]),
        ]

        expected_after_filter = [
            Message(role='user', content=[TextContent(text='Query')]),
            Message(role='assistant', content=[TextContent(text='Answer')]),
        ]

        assert self._filter(messages) == expected_after_filter

    def test_tool_response_without_call_is_removed(self):
        """A tool response whose id matches no assistant tool call is dropped."""
        messages = [
            Message(role='user', content=[TextContent(text='Query')]),
            Message(
                role='tool',
                tool_call_id='missing_call',
                content=[TextContent(text='Response')],
            ),
            Message(role='assistant', content=[TextContent(text='Answer')]),
        ]

        expected_after_filter = [
            Message(role='user', content=[TextContent(text='Query')]),
            Message(role='assistant', content=[TextContent(text='Answer')]),
        ]

        assert self._filter(messages) == expected_after_filter

    def test_partial_matched_tool_calls_retains_matched(self):
        """When a message has both matched and unmatched tool calls, retain the message and only the matched calls."""
        messages = [
            Message(role='user', content=[TextContent(text='Get data')]),
            Message(
                role='assistant',
                content=[],
                tool_calls=[
                    self._tool_call('matched_call', 'function1'),
                    self._tool_call('unmatched_call', 'function2'),
                ],
            ),
            Message(
                role='tool',
                tool_call_id='matched_call',
                content=[TextContent(text='Data')],
            ),
            Message(role='assistant', content=[TextContent(text='Result')]),
        ]

        expected = [
            Message(role='user', content=[TextContent(text='Get data')]),
            # This message should be modified to only include the matched tool call
            Message(
                role='assistant',
                content=[],
                tool_calls=[self._tool_call('matched_call', 'function1')],
            ),
            Message(
                role='tool',
                tool_call_id='matched_call',
                content=[TextContent(text='Data')],
            ),
            Message(role='assistant', content=[TextContent(text='Result')]),
        ]

        # List equality checks both the length and every element in order.
        assert self._filter(messages) == expected

0 commit comments

Comments
 (0)