Skip to content

Commit 97a03fa

Browse files
ColeMurrayenysttobitege
authored
Add Handling of Cache Prompt When Formatting Messages (All-Hands-AI#3773)
* Add Handling of Cache Prompt When Formatting Messages * Fix Value for Cache Control * Fix Value for Cache Control * Update openhands/core/message.py Co-authored-by: Engel Nyst <[email protected]> * Fix lint error * Serialize Messages if Propt Caching Is Enabled * Remove formatting message change --------- Co-authored-by: Engel Nyst <[email protected]> Co-authored-by: tobitege <[email protected]>
1 parent 06ed142 commit 97a03fa

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

openhands/core/message.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,14 @@ def serialize_model(self) -> dict:
7272

7373

7474
def format_messages(
75-
messages: Union[Message, list[Message]], with_images: bool
75+
messages: Union[Message, list[Message]],
76+
with_images: bool,
77+
with_prompt_caching: bool,
7678
) -> list[dict]:
7779
if not isinstance(messages, list):
7880
messages = [messages]
7981

80-
if with_images:
82+
if with_images or with_prompt_caching:
8183
return [message.model_dump() for message in messages]
8284

8385
converted_messages = []
@@ -113,4 +115,5 @@ def format_messages(
113115
'content': content_str,
114116
}
115117
)
118+
116119
return converted_messages

openhands/llm/llm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,4 +597,6 @@ def reset(self):
597597
def format_messages_for_llm(
598598
self, messages: Union[Message, list[Message]]
599599
) -> list[dict]:
600-
return format_messages(messages, self.vision_is_active())
600+
return format_messages(
601+
messages, self.vision_is_active(), self.is_caching_prompt_active()
602+
)

tests/integration/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,9 @@ def mock_user_response(*args, test_name, **kwargs):
185185
def mock_completion(*args, test_name, **kwargs):
186186
global cur_id
187187
messages = kwargs['messages']
188-
plain_messages = format_messages(messages, with_images=False)
188+
plain_messages = format_messages(
189+
messages, with_images=False, with_prompt_caching=False
190+
)
189191
message_str = message_separator.join(msg['content'] for msg in plain_messages)
190192

191193
# this assumes all response_(*).log filenames are in numerical order, starting from one

0 commit comments

Comments
 (0)