
Commit 5bb931e

Kaushikdkrikhanu, tobitege, xingyaoww, and enyst authored
Add prompt caching (Sonnet, Haiku only) (#3411)
* Add prompt caching
* remove anthropic-version from extra_headers
* change supports_prompt_caching method to attribute
* change caching strat and log cache statistics
* add reminder as a new message to fix caching
* fix unit test
* append reminder to the end of the last message content
* move token logs to post completion function
* fix unit test failure
* fix reminder and prompt caching
* unit tests for prompt caching
* add test
* clean up tests
* separate reminder, use latest two messages
* fix tests

---------

Co-authored-by: tobitege <[email protected]>
Co-authored-by: Xingyao Wang <[email protected]>
Co-authored-by: Engel Nyst <[email protected]>
1 parent e72dc96 commit 5bb931e

File tree

4 files changed: +300 -29 lines changed

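In short: when the configured model supports Anthropic prompt caching, the agent now sends the anthropic-beta: prompt-caching-2024-07-31 header and marks stable message content with a cache_control: {'type': 'ephemeral'} breakpoint. The following is a minimal standalone sketch of the request shape this produces through litellm; the model name, prompt text, and user message here are placeholders, not taken from the commit:

# Minimal sketch, outside OpenHands: roughly what the cached request looks like via litellm.
# Assumes ANTHROPIC_API_KEY is set in the environment; prompt text is illustrative.
import litellm

response = litellm.completion(
    model='claude-3-5-sonnet-20240620',
    messages=[
        {
            'role': 'system',
            'content': [
                {
                    'type': 'text',
                    'text': 'You are a coding agent ...',  # long, stable system prompt
                    'cache_control': {'type': 'ephemeral'},  # cache breakpoint
                }
            ],
        },
        {'role': 'user', 'content': 'Fix the failing unit test.'},
    ],
    extra_headers={'anthropic-beta': 'prompt-caching-2024-07-31'},
    temperature=0.0,
)
print(response['usage'])  # cache statistics surface here (see the llm.py changes below)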

agenthub/codeact_agent/codeact_agent.py (+36 -25)

@@ -172,26 +172,44 @@ def step(self, state: State) -> Action:
         # prepare what we want to send to the LLM
         messages = self._get_messages(state)

-        response = self.llm.completion(
-            messages=[message.model_dump() for message in messages],
-            stop=[
+        params = {
+            'messages': [message.model_dump() for message in messages],
+            'stop': [
                 '</execute_ipython>',
                 '</execute_bash>',
                 '</execute_browse>',
             ],
-            temperature=0.0,
-        )
+            'temperature': 0.0,
+        }
+
+        if self.llm.supports_prompt_caching:
+            params['extra_headers'] = {
+                'anthropic-beta': 'prompt-caching-2024-07-31',
+            }
+
+        response = self.llm.completion(**params)
+
         return self.action_parser.parse(response)

     def _get_messages(self, state: State) -> list[Message]:
         messages: list[Message] = [
             Message(
                 role='system',
-                content=[TextContent(text=self.prompt_manager.system_message)],
+                content=[
+                    TextContent(
+                        text=self.prompt_manager.system_message,
+                        cache_prompt=self.llm.supports_prompt_caching,  # Cache system prompt
+                    )
+                ],
             ),
             Message(
                 role='user',
-                content=[TextContent(text=self.prompt_manager.initial_user_message)],
+                content=[
+                    TextContent(
+                        text=self.prompt_manager.initial_user_message,
+                        cache_prompt=self.llm.supports_prompt_caching,  # if the user asks the same query,
+                    )
+                ],
             ),
         ]

@@ -214,6 +232,16 @@ def _get_messages(self, state: State) -> list[Message]:
                 else:
                     messages.append(message)

+        # Add caching to the last 2 user messages
+        if self.llm.supports_prompt_caching:
+            user_turns_processed = 0
+            for message in reversed(messages):
+                if message.role == 'user' and user_turns_processed < 2:
+                    message.content[
+                        -1
+                    ].cache_prompt = True  # Last item inside the message content
+                    user_turns_processed += 1
+
         # the latest user message is important:
         # we want to remind the agent of the environment constraints
         latest_user_message = next(
@@ -225,25 +253,8 @@ def _get_messages(self, state: State) -> list[Message]:
             ),
             None,
         )
-
-        # Get the last user text inside content
         if latest_user_message:
-            latest_user_message_text = next(
-                (
-                    t
-                    for t in reversed(latest_user_message.content)
-                    if isinstance(t, TextContent)
-                )
-            )
-            # add a reminder to the prompt
             reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with <finish></finish>.'
-
-            if latest_user_message_text:
-                latest_user_message_text.text = (
-                    latest_user_message_text.text + reminder_text
-                )
-            else:
-                latest_user_message_text = TextContent(text=reminder_text)
-            latest_user_message.content.append(latest_user_message_text)
+            latest_user_message.content.append(TextContent(text=reminder_text))

         return messages
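Beyond the system prompt and the initial user message, the loop above flags the two most recent user turns for caching, so the cache breakpoint follows the conversation. That keeps the request within Anthropic's documented limit of four cache breakpoints (system prompt, initial user message, and the last two user messages). A self-contained sketch of just that marking step, using throwaway stand-in classes instead of the project's Message/TextContent models:

# Stand-in types for illustration; the real code uses openhands.core.message models.
from dataclasses import dataclass, field


@dataclass
class _Text:
    text: str
    cache_prompt: bool = False


@dataclass
class _Msg:
    role: str
    content: list[_Text] = field(default_factory=list)


def mark_last_two_user_turns(messages: list[_Msg]) -> None:
    """Flag the last content item of the two most recent user messages for caching."""
    user_turns_processed = 0
    for message in reversed(messages):
        if message.role == 'user' and user_turns_processed < 2:
            message.content[-1].cache_prompt = True
            user_turns_processed += 1


history = [
    _Msg('user', [_Text('initial task')]),
    _Msg('assistant', [_Text('<execute_bash>ls</execute_bash>')]),
    _Msg('user', [_Text('OBSERVATION: ...')]),
    _Msg('user', [_Text('please continue')]),
]
mark_last_two_user_turns(history)
print([(m.role, m.content[-1].cache_prompt) for m in history])
# [('user', False), ('assistant', False), ('user', True), ('user', True)]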

openhands/core/message.py (+10 -1)

@@ -11,6 +11,7 @@ class ContentType(Enum):

 class Content(BaseModel):
     type: ContentType
+    cache_prompt: bool = False

     @model_serializer
     def serialize_model(self):
@@ -23,7 +24,13 @@ class TextContent(Content):

     @model_serializer
     def serialize_model(self):
-        return {'type': self.type.value, 'text': self.text}
+        data: dict[str, str | dict[str, str]] = {
+            'type': self.type.value,
+            'text': self.text,
+        }
+        if self.cache_prompt:
+            data['cache_control'] = {'type': 'ephemeral'}
+        return data


 class ImageContent(Content):
@@ -35,6 +42,8 @@ def serialize_model(self):
         images: list[dict[str, str | dict[str, str]]] = []
         for url in self.image_urls:
            images.append({'type': self.type.value, 'image_url': {'url': url}})
+        if self.cache_prompt and images:
+            images[-1]['cache_control'] = {'type': 'ephemeral'}
         return images
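The cache_prompt flag only changes serialization: when it is set, the Anthropic cache_control marker is attached to the emitted dict (for ImageContent, to the last image). A quick sanity check, assuming the package is installed and TextContent keeps a default for its type field as in the repository:

from openhands.core.message import TextContent

plain = TextContent(text='hello')
cached = TextContent(text='hello', cache_prompt=True)

# Expected per the serializers above (assuming ContentType.TEXT serializes to 'text'):
print(plain.model_dump())   # {'type': 'text', 'text': 'hello'}
print(cached.model_dump())  # {'type': 'text', 'text': 'hello', 'cache_control': {'type': 'ephemeral'}}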

openhands/llm/llm.py (+44 -3)

@@ -35,6 +35,11 @@

 message_separator = '\n\n----------\n\n'

+cache_prompting_supported_models = [
+    'claude-3-5-sonnet-20240620',
+    'claude-3-haiku-20240307',
+]
+

 class LLM:
     """The LLM class represents a Language Model instance.
@@ -58,6 +63,9 @@ def __init__(
         self.config = copy.deepcopy(config)
         self.metrics = metrics if metrics is not None else Metrics()
         self.cost_metric_supported = True
+        self.supports_prompt_caching = (
+            self.config.model in cache_prompting_supported_models
+        )

         # Set up config attributes with default values to prevent AttributeError
         LLMConfig.set_missing_attributes(self.config)
@@ -184,6 +192,7 @@ def wrapper(*args, **kwargs):

             # log the response
             message_back = resp['choices'][0]['message']['content']
+
             llm_response_logger.debug(message_back)

             # post-process to log costs
@@ -421,19 +430,51 @@ def async_streaming_completion(self):
     def supports_vision(self):
         return litellm.supports_vision(self.config.model)

-    def _post_completion(self, response: str) -> None:
+    def _post_completion(self, response) -> None:
         """Post-process the completion response."""
         try:
             cur_cost = self.completion_cost(response)
         except Exception:
             cur_cost = 0
+
+        stats = ''
         if self.cost_metric_supported:
-            logger.info(
-                'Cost: %.2f USD | Accumulated Cost: %.2f USD',
+            stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                 cur_cost,
                 self.metrics.accumulated_cost,
             )

+        usage = response.get('usage')
+
+        if usage:
+            input_tokens = usage.get('prompt_tokens')
+            output_tokens = usage.get('completion_tokens')
+
+            if input_tokens:
+                stats += 'Input tokens: ' + str(input_tokens) + '\n'
+
+            if output_tokens:
+                stats += 'Output tokens: ' + str(output_tokens) + '\n'
+
+            model_extra = usage.get('model_extra', {})
+
+            cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
+            if cache_creation_input_tokens:
+                stats += (
+                    'Input tokens (cache write): '
+                    + str(cache_creation_input_tokens)
+                    + '\n'
+                )
+
+            cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
+            if cache_read_input_tokens:
+                stats += (
+                    'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
+                )
+
+        if stats:
+            logger.info(stats)
+
     def get_token_count(self, messages):
         """Get the number of tokens in a list of messages.
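_post_completion now also logs the cache counters Anthropic returns, which the diff reads from the usage object's model_extra. A small sketch of the same stats assembly over a plain dict, with made-up numbers for a first call (cache write) and a follow-up call (cache read):

def summarize_usage(usage: dict) -> str:
    """Mirror the stats string built in _post_completion above (sketch only)."""
    stats = ''
    if usage.get('prompt_tokens'):
        stats += 'Input tokens: ' + str(usage['prompt_tokens']) + '\n'
    if usage.get('completion_tokens'):
        stats += 'Output tokens: ' + str(usage['completion_tokens']) + '\n'
    model_extra = usage.get('model_extra') or {}
    if model_extra.get('cache_creation_input_tokens'):
        stats += 'Input tokens (cache write): ' + str(model_extra['cache_creation_input_tokens']) + '\n'
    if model_extra.get('cache_read_input_tokens'):
        stats += 'Input tokens (cache read): ' + str(model_extra['cache_read_input_tokens']) + '\n'
    return stats


# Hypothetical payloads: the first request writes the cache, the next one reads it.
print(summarize_usage({'prompt_tokens': 120, 'completion_tokens': 80,
                       'model_extra': {'cache_creation_input_tokens': 3000}}))
print(summarize_usage({'prompt_tokens': 120, 'completion_tokens': 95,
                       'model_extra': {'cache_read_input_tokens': 3000}}))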
