Skip to content

Fix truncation state persistence bug #7212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 36 additions & 23 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ class AgentController:
ChangeAgentStateAction,
AgentStateChangedObservation,
)
_cached_first_user_message: MessageAction | None = None
_cached_first_user_message: MessageAction | None = None

def __init__(
self,
Expand Down Expand Up @@ -926,6 +928,9 @@ def _init_history(self) -> None:
if self.state.end_id >= 0
else self.event_stream.get_latest_event_id()
)

# We will restore from here
original_start_id = self.state.start_id

# sanity check
if start_id > end_id + 1:
Expand All @@ -941,20 +946,7 @@ def _init_history(self) -> None:
# If we have a truncation point, get first user message and then rest of history
if hasattr(self.state, 'truncation_id') and self.state.truncation_id > 0:
# Find first user message from stream
first_user_msg = next(
(
e
for e in self.event_stream.get_events(
start_id=start_id,
end_id=end_id,
reverse=False,
filter_out_type=self.filter_out,
filter_hidden=True,
)
if isinstance(e, MessageAction) and e.source == EventSource.USER
),
None,
)
first_user_msg = self._first_user_message(start_id=start_id, end_id=end_id)
if first_user_msg:
events.append(first_user_msg)

Expand Down Expand Up @@ -1023,8 +1015,8 @@ def _init_history(self) -> None:
else:
self.state.history = events

# make sure history is in sync
self.state.start_id = start_id
# from the (original) first user message
self.state.start_id = original_start_id

def _handle_long_context_error(self) -> None:
# When context window is exceeded, keep roughly half of agent interactions
Expand Down Expand Up @@ -1197,24 +1189,45 @@ def _is_awaiting_observation(self):
return result
return False

def _first_user_message(self) -> MessageAction | None:
def _first_user_message(
self, start_id: int = -1, end_id: int = -1
) -> MessageAction | None:
"""Get the first user message for this agent.

For regular agents, this is the first user message from the beginning (start_id=0).
For delegate agents, this is the first user message after the delegate's start_id.

Returns:
MessageAction | None: The first user message, or None if no user message found
The first user message, or None if no user message found (though that should never happen)
"""
# Find the first user message from the appropriate starting point
user_messages = list(self.event_stream.get_events(start_id=self.state.start_id))
if self._cached_first_user_message is not None:
return self._cached_first_user_message

# start_id is typically saved in state as state.start_id
if start_id == -1:
start_id = self.state.start_id if self.state.start_id >= 0 else 0

# end_id is saved in state as state.end_id
if end_id == -1:
end_id = (
self.state.end_id
if self.state.end_id >= 0
else self.event_stream.get_latest_event_id()
)

# Get and return the first user message
return next(
# Find the first user message
self._cached_first_user_message = next(
(
e
for e in user_messages
for e in self.event_stream.get_events(
start_id=start_id,
end_id=end_id,
reverse=False,
filter_out_type=self.filter_out,
filter_hidden=True,
)
if isinstance(e, MessageAction) and e.source == EventSource.USER
),
None,
)
return self._cached_first_user_message
15 changes: 15 additions & 0 deletions tests/unit/test_agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,7 @@ async def test_first_user_message_with_identical_content():
"""
Test that _first_user_message correctly identifies the first user message
even when multiple messages have identical content but different IDs.
Also verifies that the result is properly cached.

The issue we're checking is that the comparison (action == self._first_user_message())
should correctly differentiate between messages with the same content but different IDs.
Expand Down Expand Up @@ -1038,6 +1039,20 @@ async def test_first_user_message_with_identical_content():
second_message.id != first_user_message.id
) # This should be False, but may be True if there's a bug

# Verify caching behavior
assert (
controller._cached_first_user_message is not None
) # Cache should be populated
assert (
controller._cached_first_user_message is first_user_message
) # Cache should store the same object

# Mock get_events to verify it's not called again
with patch.object(event_stream, 'get_events') as mock_get_events:
cached_message = controller._first_user_message()
assert cached_message is first_user_message # Should return cached object
mock_get_events.assert_not_called() # Should not call get_events again

await controller.close()


Expand Down
Loading
Loading