Skip to content

Commit 9f971c5

Browse files
csmith49Calvin Smithneubig
authored
fix: Context window truncation using CondensationAction (All-Hands-AI#7578)
Co-authored-by: Calvin Smith <[email protected]> Co-authored-by: Graham Neubig <[email protected]>
1 parent c8225b3 commit 9f971c5

17 files changed

+412
-445
lines changed

openhands/agenthub/browsing_agent/browsing_agent.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,13 @@ def step(self, state: State) -> Action:
150150
last_obs = None
151151
last_action = None
152152

153-
if EVAL_MODE and len(state.history) == 1:
153+
if EVAL_MODE and len(state.view) == 1:
154154
# for webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
155155
# initialize and retrieve the first observation by issuing an noop OP
156156
# For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites
157157
return BrowseInteractiveAction(browser_actions='noop()')
158158

159-
for event in state.history:
159+
for event in state.view:
160160
if isinstance(event, BrowseInteractiveAction):
161161
prev_actions.append(event.browser_actions)
162162
last_action = event

openhands/agenthub/dummy_agent/agent.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def step(self, state: State) -> Action:
130130

131131
if 'observations' in prev_step and prev_step['observations']:
132132
expected_observations = prev_step['observations']
133-
hist_events = state.history[-len(expected_observations) :]
133+
hist_events = state.view[-len(expected_observations) :]
134134

135135
if len(hist_events) < len(expected_observations):
136136
print(

openhands/agenthub/visualbrowsing_agent/visualbrowsing_agent.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,13 @@ def step(self, state: State) -> Action:
204204
last_action = None
205205
set_of_marks = None # Initialize set_of_marks to None
206206

207-
if len(state.history) == 1:
207+
if len(state.view) == 1:
208208
# for visualwebarena, webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
209209
# initialize and retrieve the first observation by issuing an noop OP
210210
# For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites
211211
return BrowseInteractiveAction(browser_actions='noop(1000)')
212212

213-
for event in state.history:
213+
for event in state.view:
214214
if isinstance(event, BrowseInteractiveAction):
215215
prev_actions.append(event)
216216
last_action = event

openhands/controller/agent_controller.py

+7-37
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
from openhands.events.action.agent import CondensationAction, RecallAction
5858
from openhands.events.event import Event
5959
from openhands.events.observation import (
60-
AgentCondensationObservation,
6160
AgentDelegateObservation,
6261
AgentStateChangedObservation,
6362
ErrorObservation,
@@ -928,12 +927,6 @@ def _init_history(self) -> None:
928927
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
929928
- Excludes all events between the action and observation
930929
- Includes the delegate action and observation themselves
931-
932-
The history is loaded in two parts if truncation_id is set:
933-
1. First user message from start_id onwards
934-
2. Rest of history from truncation_id to the end
935-
936-
Otherwise loads normally from start_id.
937930
"""
938931
# define range of events to fetch
939932
# delegates start with a start_id and initially won't find any events
@@ -956,29 +949,6 @@ def _init_history(self) -> None:
956949

957950
events: list[Event] = []
958951

959-
# If we have a truncation point, get first user message and then rest of history
960-
if hasattr(self.state, 'truncation_id') and self.state.truncation_id > 0:
961-
# Find first user message from stream
962-
first_user_msg = next(
963-
(
964-
e
965-
for e in self.event_stream.get_events(
966-
start_id=start_id,
967-
end_id=end_id,
968-
reverse=False,
969-
filter_out_type=self.filter_out,
970-
filter_hidden=True,
971-
)
972-
if isinstance(e, MessageAction) and e.source == EventSource.USER
973-
),
974-
None,
975-
)
976-
if first_user_msg:
977-
events.append(first_user_msg)
978-
979-
# the rest of the events are from the truncation point
980-
start_id = self.state.truncation_id
981-
982952
# Get rest of history
983953
events_to_add = list(
984954
self.event_stream.get_events(
@@ -1046,16 +1016,20 @@ def _init_history(self) -> None:
10461016

10471017
def _handle_long_context_error(self) -> None:
10481018
# When context window is exceeded, keep roughly half of agent interactions
1049-
self.state.history = self._apply_conversation_window(self.state.history)
1019+
kept_event_ids = {
1020+
e.id for e in self._apply_conversation_window(self.state.history)
1021+
}
1022+
forgotten_event_ids = {e.id for e in self.state.history} - kept_event_ids
10501023

10511024
# Save the ID of the first event in our truncated history for future reloading
10521025
if self.state.history:
10531026
self.state.start_id = self.state.history[0].id
10541027

10551028
# Add an error event to trigger another step by the agent
10561029
self.event_stream.add_event(
1057-
AgentCondensationObservation(
1058-
content='Trimming prompt to meet context window limitations'
1030+
CondensationAction(
1031+
forgotten_events_start_id=min(forgotten_event_ids),
1032+
forgotten_events_end_id=max(forgotten_event_ids),
10591033
),
10601034
EventSource.AGENT,
10611035
)
@@ -1133,10 +1107,6 @@ def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
11331107
# if it's an action with source == EventSource.AGENT, we're good
11341108
break
11351109

1136-
# Save where to continue from in next reload
1137-
if kept_events:
1138-
self.state.truncation_id = kept_events[0].id
1139-
11401110
# Ensure first user message is included
11411111
if first_user_msg and first_user_msg not in kept_events:
11421112
kept_events = [first_user_msg] + kept_events

openhands/controller/state/state.py

+26-6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from openhands.events.action.agent import AgentFinishAction
1616
from openhands.events.event import Event, EventSource
1717
from openhands.llm.metrics import Metrics
18+
from openhands.memory.view import View
1819
from openhands.storage.files import FileStore
1920
from openhands.storage.locations import get_conversation_agent_state_filename
2021

@@ -96,8 +97,6 @@ class State:
9697
# start_id and end_id track the range of events in history
9798
start_id: int = -1
9899
end_id: int = -1
99-
# truncation_id tracks where to load history after context window truncation
100-
truncation_id: int = -1
101100

102101
delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
103102
# NOTE: This will never be used by the controller, but it can be used by different
@@ -170,6 +169,12 @@ def __getstate__(self):
170169
# don't pickle history, it will be restored from the event stream
171170
state = self.__dict__.copy()
172171
state['history'] = []
172+
173+
# Remove any view caching attributes. They'll be rebuilt frmo the
174+
# history after that gets reloaded.
175+
state.pop('_history_checksum', None)
176+
state.pop('_view', None)
177+
173178
return state
174179

175180
def __setstate__(self, state):
@@ -183,7 +188,7 @@ def get_current_user_intent(self) -> tuple[str | None, list[str] | None]:
183188
"""Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet."""
184189
last_user_message = None
185190
last_user_message_image_urls: list[str] | None = []
186-
for event in reversed(self.history):
191+
for event in reversed(self.view):
187192
if isinstance(event, MessageAction) and event.source == 'user':
188193
last_user_message = event.content
189194
last_user_message_image_urls = event.image_urls
@@ -194,13 +199,13 @@ def get_current_user_intent(self) -> tuple[str | None, list[str] | None]:
194199
return last_user_message, last_user_message_image_urls
195200

196201
def get_last_agent_message(self) -> MessageAction | None:
197-
for event in reversed(self.history):
202+
for event in reversed(self.view):
198203
if isinstance(event, MessageAction) and event.source == EventSource.AGENT:
199204
return event
200205
return None
201206

202207
def get_last_user_message(self) -> MessageAction | None:
203-
for event in reversed(self.history):
208+
for event in reversed(self.view):
204209
if isinstance(event, MessageAction) and event.source == EventSource.USER:
205210
return event
206211
return None
@@ -211,7 +216,22 @@ def to_llm_metadata(self, agent_name: str) -> dict:
211216
'trace_version': openhands.__version__,
212217
'tags': [
213218
f'agent:{agent_name}',
214-
f'web_host:{os.environ.get("WEB_HOST", "unspecified")}',
219+
f"web_host:{os.environ.get('WEB_HOST', 'unspecified')}",
215220
f'openhands_version:{openhands.__version__}',
216221
],
217222
}
223+
224+
@property
225+
def view(self) -> View:
226+
# Compute a simple checksum from the history to see if we can re-use any
227+
# cached view.
228+
history_checksum = len(self.history)
229+
old_history_checksum = getattr(self, '_history_checksum', -1)
230+
231+
# If the history has changed, we need to re-create the view and update
232+
# the caching.
233+
if history_checksum != old_history_checksum:
234+
self._history_checksum = history_checksum
235+
self._view = View.from_events(self.history)
236+
237+
return self._view

openhands/memory/__init__.py

-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
from openhands.memory.condenser import Condenser
2-
3-
__all__ = ['Condenser']

openhands/memory/condenser/condenser.py

+6-73
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22

33
from abc import ABC, abstractmethod
44
from contextlib import contextmanager
5-
from typing import Any, overload
5+
from typing import Any
66

77
from pydantic import BaseModel
88

99
from openhands.controller.state.state import State
1010
from openhands.core.config.condenser_config import CondenserConfig
1111
from openhands.events.action.agent import CondensationAction
12-
from openhands.events.event import Event
13-
from openhands.events.observation.agent import AgentCondensationObservation
12+
from openhands.memory.view import View
1413

1514
CONDENSER_METADATA_KEY = 'condenser_meta'
1615
"""Key identifying where metadata is stored in a `State` object's `extra_data` field."""
@@ -34,69 +33,6 @@ def get_condensation_metadata(state: State) -> list[dict[str, Any]]:
3433
"""Registry of condenser configurations to their corresponding condenser classes."""
3534

3635

37-
class View(BaseModel):
38-
"""Linearly ordered view of events.
39-
40-
Produced by a condenser to indicate the included events are ready to process as LLM input.
41-
"""
42-
43-
events: list[Event]
44-
45-
def __len__(self) -> int:
46-
return len(self.events)
47-
48-
def __iter__(self):
49-
return iter(self.events)
50-
51-
# To preserve list-like indexing, we ideally support slicing and position-based indexing.
52-
# The only challenge with that is switching the return type based on the input type -- we
53-
# can mark the different signatures for MyPy with `@overload` decorators.
54-
55-
@overload
56-
def __getitem__(self, key: slice) -> list[Event]: ...
57-
58-
@overload
59-
def __getitem__(self, key: int) -> Event: ...
60-
61-
def __getitem__(self, key: int | slice) -> Event | list[Event]:
62-
if isinstance(key, slice):
63-
start, stop, step = key.indices(len(self))
64-
return [self[i] for i in range(start, stop, step)]
65-
elif isinstance(key, int):
66-
return self.events[key]
67-
else:
68-
raise ValueError(f'Invalid key type: {type(key)}')
69-
70-
@staticmethod
71-
def from_events(events: list[Event]) -> View:
72-
"""Create a view from a list of events, respecting the semantics of any condensation events."""
73-
forgotten_event_ids: set[int] = set()
74-
for event in events:
75-
if isinstance(event, CondensationAction):
76-
forgotten_event_ids.update(event.forgotten)
77-
78-
kept_events = [event for event in events if event.id not in forgotten_event_ids]
79-
80-
# If we have a summary, insert it at the specified offset.
81-
summary: str | None = None
82-
summary_offset: int | None = None
83-
84-
# The relevant summary is always in the last condensation event (i.e., the most recent one).
85-
for event in reversed(events):
86-
if isinstance(event, CondensationAction):
87-
if event.summary is not None and event.summary_offset is not None:
88-
summary = event.summary
89-
summary_offset = event.summary_offset
90-
break
91-
92-
if summary is not None and summary_offset is not None:
93-
kept_events.insert(
94-
summary_offset, AgentCondensationObservation(content=summary)
95-
)
96-
97-
return View(events=kept_events)
98-
99-
10036
class Condensation(BaseModel):
10137
"""Produced by a condenser to indicate the history has been condensed."""
10238

@@ -150,13 +86,13 @@ def metadata_batch(self, state: State):
15086
self.write_metadata(state)
15187

15288
@abstractmethod
153-
def condense(self, events: list[Event]) -> View | Condensation:
89+
def condense(self, View) -> View | Condensation:
15490
"""Condense a sequence of events into a potentially smaller list.
15591
15692
New condenser strategies should override this method to implement their own condensation logic. Call `self.add_metadata` in the implementation to record any relevant per-condensation diagnostic information.
15793
15894
Args:
159-
events: A list of events representing the entire history of the agent.
95+
View: A view of the history containing all events that should be condensed.
16096
16197
Returns:
16298
View | Condensation: A condensed view of the events or an event indicating the history has been condensed.
@@ -165,7 +101,7 @@ def condense(self, events: list[Event]) -> View | Condensation:
165101
def condensed_history(self, state: State) -> View | Condensation:
166102
"""Condense the state's history."""
167103
with self.metadata_batch(state):
168-
return self.condense(state.history)
104+
return self.condense(state.view)
169105

170106
@classmethod
171107
def register_config(cls, configuration_type: type[CondenserConfig]) -> None:
@@ -221,10 +157,7 @@ def should_condense(self, view: View) -> bool:
221157
def get_condensation(self, view: View) -> Condensation:
222158
"""Get the condensation from a view."""
223159

224-
def condense(self, events: list[Event]) -> View | Condensation:
225-
# Convert the state to a view. This might require some condenser-specific logic.
226-
view = View.from_events(events)
227-
160+
def condense(self, view: View) -> View | Condensation:
228161
# If we trigger the condenser-specific condensation threshold, compute and return
229162
# the condensation.
230163
if self.should_condense(view):

openhands/memory/condenser/impl/browser_output_condenser.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ def __init__(self, attention_window: int = 1):
1717
self.attention_window = attention_window
1818
super().__init__()
1919

20-
def condense(self, events: list[Event]) -> View | Condensation:
20+
def condense(self, view: View) -> View | Condensation:
2121
"""Replace the content of browser observations outside of the attention window with a placeholder."""
2222
results: list[Event] = []
2323
cnt: int = 0
24-
for event in reversed(events):
24+
for event in reversed(view):
2525
if (
2626
isinstance(event, BrowserOutputObservation)
2727
and cnt >= self.attention_window

openhands/memory/condenser/impl/no_op_condenser.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
from __future__ import annotations
22

33
from openhands.core.config.condenser_config import NoOpCondenserConfig
4-
from openhands.events.event import Event
54
from openhands.memory.condenser.condenser import Condensation, Condenser, View
65

76

87
class NoOpCondenser(Condenser):
98
"""A condenser that does nothing to the event sequence."""
109

11-
def condense(self, events: list[Event]) -> View | Condensation:
10+
def condense(self, view: View) -> View | Condensation:
1211
"""Returns the list of events unchanged."""
13-
return View(events=events)
12+
return view
1413

1514
@classmethod
1615
def from_config(cls, config: NoOpCondenserConfig) -> NoOpCondenser:

openhands/memory/condenser/impl/observation_masking_condenser.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,11 @@ def __init__(self, attention_window: int = 5):
1515

1616
super().__init__()
1717

18-
def condense(self, events: list[Event]) -> View | Condensation:
18+
def condense(self, view: View) -> View | Condensation:
1919
"""Replace the content of observations outside of the attention window with a placeholder."""
2020
results: list[Event] = []
21-
for i, event in enumerate(events):
22-
if (
23-
isinstance(event, Observation)
24-
and i < len(events) - self.attention_window
25-
):
21+
for i, event in enumerate(view):
22+
if isinstance(event, Observation) and i < len(view) - self.attention_window:
2623
results.append(AgentCondensationObservation('<MASKED>'))
2724
else:
2825
results.append(event)

0 commit comments

Comments
 (0)