Add prompt caching (Sonnet, Haiku only) #3411

Merged: 39 commits merged into main from add-prompt-caching on Aug 27, 2024
Commits (39)
260d7ea
Add prompt caching
Kaushikdkrikhanu Aug 16, 2024
c757ef3
Merge branch 'main' into add-prompt-caching
Kaushikdkrikhanu Aug 16, 2024
ecb12ad
remove anthropic-version from extra_headers
Kaushikdkrikhanu Aug 16, 2024
0ea7e68
change supports_prompt_caching method to attribute
Kaushikdkrikhanu Aug 16, 2024
e02eae6
Merge branch 'main' of https://github.com/Kaushikdkrikhanu/OpenDevin …
Kaushikdkrikhanu Aug 16, 2024
6b9be67
Merge branch 'main' into add-prompt-caching
tobitege Aug 17, 2024
9fef189
Merge branch 'main' into add-prompt-caching
xingyaoww Aug 18, 2024
abcc8a1
Merge branch 'main' into add-prompt-caching
Kaushikdkrikhanu Aug 18, 2024
ffb6035
change caching strat and log cache statistics
Kaushikdkrikhanu Aug 19, 2024
87eb4a6
add reminder as a new message to fix caching
Kaushikdkrikhanu Aug 19, 2024
ae4f99a
Merge branch 'main' of https://github.com/Kaushikdkrikhanu/OpenDevin …
Kaushikdkrikhanu Aug 19, 2024
5b101a0
'Merge branch 'main' of https://github.com/Kaushikdkrikhanu/OpenDevin…
Kaushikdkrikhanu Aug 19, 2024
417a69e
fix unit test
Kaushikdkrikhanu Aug 19, 2024
aa515c9
Merge branch 'main' into add-prompt-caching
Kaushikdkrikhanu Aug 19, 2024
005602c
Merge branch 'main' into add-prompt-caching
Kaushikdkrikhanu Aug 19, 2024
5034633
Merge branch 'main' into add-prompt-caching
tobitege Aug 20, 2024
5c1317a
append reminder to the end of the last message content
Kaushikdkrikhanu Aug 20, 2024
d08cf8a
move token logs to post completion function
Kaushikdkrikhanu Aug 20, 2024
4441455
Merge branch 'main' into add-prompt-caching
Kaushikdkrikhanu Aug 20, 2024
1bee42f
Merge branch 'main' into add-prompt-caching
Kaushikdkrikhanu Aug 20, 2024
14888f7
fix unit test failure
Kaushikdkrikhanu Aug 20, 2024
2a35687
Merge branch 'main' into add-prompt-caching
Kaushikdkrikhanu Aug 21, 2024
a5d08fa
Merge branch 'main' into add-prompt-caching
tobitege Aug 21, 2024
ae66b5f
Merge branch 'main' into add-prompt-caching
enyst Aug 21, 2024
08004e4
Merge branch 'main' into add-prompt-caching
enyst Aug 21, 2024
749072d
fix reminder and prompt caching
enyst Aug 21, 2024
24d2e66
Merge branch 'main' into add-prompt-caching
Kaushikdkrikhanu Aug 21, 2024
386d14f
unit tests for prompt caching
enyst Aug 22, 2024
2553ed2
Merge branch 'main' into add-prompt-caching
Kaushikdkrikhanu Aug 22, 2024
732541c
Merge branch 'main' into add-prompt-caching
Kaushikdkrikhanu Aug 23, 2024
cb2e4cc
add test
enyst Aug 23, 2024
90cf091
Merge branch 'main' into add-prompt-caching
Kaushikdkrikhanu Aug 23, 2024
5e17027
clean up tests
enyst Aug 24, 2024
6d8e2a8
Merge branch 'add-prompt-caching' of github.com:Kaushikdkrikhanu/Open…
enyst Aug 24, 2024
8045111
separate reminder, use latest two messages
enyst Aug 24, 2024
fa516f7
Merge branch 'main' of github.com:All-Hands-AI/OpenHands into add-pro…
enyst Aug 24, 2024
1273d58
fix tests
enyst Aug 24, 2024
0f95bc8
Merge branch 'main' into add-prompt-caching
Kaushikdkrikhanu Aug 26, 2024
5667eca
Merge branch 'main' into add-prompt-caching
tobitege Aug 26, 2024
agenthub/codeact_agent/codeact_agent.py (61 changes: 36 additions, 25 deletions)
@@ -172,26 +172,44 @@ def step(self, state: State) -> Action:
# prepare what we want to send to the LLM
messages = self._get_messages(state)

response = self.llm.completion(
messages=[message.model_dump() for message in messages],
stop=[
params = {
'messages': [message.model_dump() for message in messages],
'stop': [
'</execute_ipython>',
'</execute_bash>',
'</execute_browse>',
],
temperature=0.0,
)
'temperature': 0.0,
}

if self.llm.supports_prompt_caching:
params['extra_headers'] = {
'anthropic-beta': 'prompt-caching-2024-07-31',
}

response = self.llm.completion(**params)

return self.action_parser.parse(response)

def _get_messages(self, state: State) -> list[Message]:
messages: list[Message] = [
Message(
role='system',
content=[TextContent(text=self.prompt_manager.system_message)],
content=[
TextContent(
text=self.prompt_manager.system_message,
cache_prompt=self.llm.supports_prompt_caching, # Cache system prompt
)
],
),
Message(
role='user',
content=[TextContent(text=self.prompt_manager.initial_user_message)],
content=[
TextContent(
text=self.prompt_manager.initial_user_message,
cache_prompt=self.llm.supports_prompt_caching, # Cache the initial user message
)
],
),
]

@@ -214,6 +232,16 @@ def _get_messages(self, state: State) -> list[Message]:
else:
messages.append(message)

# Add caching to the last 2 user messages
if self.llm.supports_prompt_caching:
user_turns_processed = 0
for message in reversed(messages):
if message.role == 'user' and user_turns_processed < 2:
message.content[
-1
].cache_prompt = True # Last item inside the message content
user_turns_processed += 1

# the latest user message is important:
# we want to remind the agent of the environment constraints
latest_user_message = next(
@@ -225,25 +253,8 @@
),
None,
)

# Get the last user text inside content
if latest_user_message:
latest_user_message_text = next(
(
t
for t in reversed(latest_user_message.content)
if isinstance(t, TextContent)
)
)
# add a reminder to the prompt
reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with <finish></finish>.'

if latest_user_message_text:
latest_user_message_text.text = (
latest_user_message_text.text + reminder_text
)
else:
latest_user_message_text = TextContent(text=reminder_text)
latest_user_message.content.append(latest_user_message_text)
latest_user_message.content.append(TextContent(text=reminder_text))

return messages
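
For orientation, here is a rough sketch (not part of the PR diff) of what `self.llm.completion(**params)` ends up receiving once caching is enabled. The `cache_control` block and the `anthropic-beta` header follow Anthropic's prompt-caching beta as used above; the placeholder strings and the exact shape litellm forwards are illustrative assumptions.

```python
# Illustrative sketch only; placeholder text is not from the PR.
params = {
    'messages': [
        {
            'role': 'system',
            'content': [
                {
                    'type': 'text',
                    'text': '<system prompt>',
                    'cache_control': {'type': 'ephemeral'},  # cache breakpoint
                }
            ],
        },
        # ... conversation turns; the last two user messages also carry
        # a cache_control entry on their final content item ...
    ],
    'stop': ['</execute_ipython>', '</execute_bash>', '</execute_browse>'],
    'temperature': 0.0,
    'extra_headers': {'anthropic-beta': 'prompt-caching-2024-07-31'},
}
# self.llm.completion(**params) is then called as in step() above.
```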
openhands/core/message.py (11 changes: 10 additions, 1 deletion)
@@ -11,6 +11,7 @@ class ContentType(Enum):

class Content(BaseModel):
type: ContentType
cache_prompt: bool = False

@model_serializer
def serialize_model(self):
@@ -23,7 +24,13 @@ class TextContent(Content):

@model_serializer
def serialize_model(self):
return {'type': self.type.value, 'text': self.text}
data: dict[str, str | dict[str, str]] = {
'type': self.type.value,
'text': self.text,
}
if self.cache_prompt:
data['cache_control'] = {'type': 'ephemeral'}
return data


class ImageContent(Content):
@@ -35,6 +42,8 @@ def serialize_model(self):
images: list[dict[str, str | dict[str, str]]] = []
for url in self.image_urls:
images.append({'type': self.type.value, 'image_url': {'url': url}})
if self.cache_prompt and images:
images[-1]['cache_control'] = {'type': 'ephemeral'}
return images


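
A minimal usage sketch of the serializers above; it assumes `TextContent` defaults its `type` field to the text content type, which is not shown in this diff. The dictionaries in the comments are what pydantic's `model_dump()` returns via the `model_serializer` hooks.

```python
# Minimal sketch; assumes TextContent defaults type to ContentType.TEXT.
from openhands.core.message import TextContent

cached = TextContent(text='You are a helpful agent.', cache_prompt=True)
print(cached.model_dump())
# {'type': 'text', 'text': 'You are a helpful agent.',
#  'cache_control': {'type': 'ephemeral'}}

plain = TextContent(text='Hello')
print(plain.model_dump())
# {'type': 'text', 'text': 'Hello'}   <- no cache_control when cache_prompt is False
```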
openhands/llm/llm.py (47 changes: 44 additions, 3 deletions)
@@ -35,6 +35,11 @@

message_separator = '\n\n----------\n\n'

cache_prompting_supported_models = [
'claude-3-5-sonnet-20240620',
'claude-3-haiku-20240307',
]


class LLM:
"""The LLM class represents a Language Model instance.
@@ -58,6 +63,9 @@ def __init__(
self.config = copy.deepcopy(config)
self.metrics = metrics if metrics is not None else Metrics()
self.cost_metric_supported = True
self.supports_prompt_caching = (
self.config.model in cache_prompting_supported_models
)

# Set up config attributes with default values to prevent AttributeError
LLMConfig.set_missing_attributes(self.config)
@@ -184,6 +192,7 @@ def wrapper(*args, **kwargs):

# log the response
message_back = resp['choices'][0]['message']['content']

llm_response_logger.debug(message_back)

# post-process to log costs
@@ -421,19 +430,51 @@ def async_streaming_completion(self):
def supports_vision(self):
return litellm.supports_vision(self.config.model)

def _post_completion(self, response: str) -> None:
def _post_completion(self, response) -> None:
"""Post-process the completion response."""
try:
cur_cost = self.completion_cost(response)
except Exception:
cur_cost = 0

stats = ''
if self.cost_metric_supported:
logger.info(
'Cost: %.2f USD | Accumulated Cost: %.2f USD',
stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
cur_cost,
self.metrics.accumulated_cost,
)

usage = response.get('usage')

if usage:
input_tokens = usage.get('prompt_tokens')
output_tokens = usage.get('completion_tokens')

if input_tokens:
stats += 'Input tokens: ' + str(input_tokens) + '\n'

if output_tokens:
stats += 'Output tokens: ' + str(output_tokens) + '\n'

model_extra = usage.get('model_extra', {})

cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
if cache_creation_input_tokens:
stats += (
'Input tokens (cache write): '
+ str(cache_creation_input_tokens)
+ '\n'
)

cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
if cache_read_input_tokens:
stats += (
'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
)

if stats:
logger.info(stats)

def get_token_count(self, messages):
"""Get the number of tokens in a list of messages.

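
A hypothetical usage sketch of the new `supports_prompt_caching` attribute; the `LLMConfig` import path and its `model` field are assumed from the surrounding repository layout rather than spelled out in this diff. Models outside `cache_prompting_supported_models` simply never get the caching headers, and when a cached prefix is reused the `_post_completion` log above picks up `Input tokens (cache read)` / `(cache write)` lines from the response usage.

```python
# Hypothetical sketch; import paths and LLMConfig(model=...) are assumptions.
from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM

sonnet = LLM(config=LLMConfig(model='claude-3-5-sonnet-20240620'))
assert sonnet.supports_prompt_caching       # listed in cache_prompting_supported_models

gpt4o = LLM(config=LLMConfig(model='gpt-4o'))
assert not gpt4o.supports_prompt_caching    # caching headers are never added
```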