Skip to content

Commit 6031934

Browse files
committed
feat: add multimodal support for ChatMessage
1 parent e1aebb8 commit 6031934

File tree

10 files changed

+605
-90
lines changed

10 files changed

+605
-90
lines changed

haystack_experimental/components/generators/anthropic/chat/chat_generator.py

+98-20
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44

55
import json
66
import logging
7+
from base64 import b64encode
78
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
89

910
from haystack import component, default_from_dict
1011
from haystack.dataclasses import StreamingChunk
1112
from haystack.lazy_imports import LazyImport
1213
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace
1314

14-
from haystack_experimental.dataclasses import ChatMessage, ToolCall
15+
from haystack_experimental.dataclasses import ChatMessage, ToolCall, ByteStream
1516
from haystack_experimental.dataclasses.chat_message import ChatRole, ToolCallResult
1617
from haystack_experimental.dataclasses.tool import Tool, deserialize_tools_inplace
1718

@@ -38,7 +39,9 @@
3839
# - AnthropicChatGenerator fails with ImportError at init (due to anthropic_integration_import.check()).
3940

4041
if anthropic_integration_import.is_successful():
41-
chatgenerator_base_class: Type[AnthropicChatGeneratorBase] = AnthropicChatGeneratorBase
42+
chatgenerator_base_class: Type[AnthropicChatGeneratorBase] = (
43+
AnthropicChatGeneratorBase
44+
)
4245
else:
4346
chatgenerator_base_class: Type[object] = object # type: ignore[no-redef]
4447

@@ -57,7 +60,9 @@ def _update_anthropic_message_with_tool_call_results(
5760

5861
for tool_call_result in tool_call_results:
5962
if tool_call_result.origin.id is None:
60-
raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.")
63+
raise ValueError(
64+
"`ToolCall` must have a non-null `id` attribute to be used with Anthropic."
65+
)
6166
anthropic_msg["content"].append(
6267
{
6368
"type": "tool_result",
@@ -68,7 +73,9 @@ def _update_anthropic_message_with_tool_call_results(
6873
)
6974

7075

71-
def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[Dict[str, Any]]:
76+
def _convert_tool_calls_to_anthropic_format(
77+
tool_calls: List[ToolCall],
78+
) -> List[Dict[str, Any]]:
7279
"""
7380
Convert a list of tool calls to the format expected by Anthropic Chat API.
7481
@@ -78,7 +85,9 @@ def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[
7885
anthropic_tool_calls = []
7986
for tc in tool_calls:
8087
if tc.id is None:
81-
raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.")
88+
raise ValueError(
89+
"`ToolCall` must have a non-null `id` attribute to be used with Anthropic."
90+
)
8291
anthropic_tool_calls.append(
8392
{
8493
"type": "tool_use",
@@ -90,6 +99,44 @@ def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[
9099
return anthropic_tool_calls
91100

92101

102+
def _convert_media_to_anthropic_format(media: List[ByteStream]) -> List[Dict[str, Any]]:
    """
    Translate media items into Anthropic Chat API content blocks.

    Each item is inspected through its `type`, `subtype` and `mime_type`
    attributes. Items whose `type` is "image" become "image" blocks, and items
    whose `type` is "application" with `subtype` "pdf" become "document" blocks.
    In both cases the raw `data` is embedded as a base64 string.

    :param media: The list of ByteStreams to convert.
    :return: A list of dictionaries in the format expected by Anthropic API,
        in the same order as the input.
    :raises ValueError: If an item is neither an image nor a PDF.
    """
    converted: List[Dict[str, Any]] = []
    for stream in media:
        # Decide which kind of content block this item maps to; the payload
        # layout is the same for both supported kinds.
        if stream.type == "image":
            block_type = "image"
        elif stream.type == "application" and stream.subtype == "pdf":
            block_type = "document"
        else:
            raise ValueError(
                f"Unsupported media type '{stream.mime_type}' for Anthropic completions."
            )

        converted.append(
            {
                "type": block_type,
                "source": {
                    "type": "base64",
                    "media_type": stream.mime_type,
                    # b64encode returns bytes; decode so the value is a plain string.
                    "data": b64encode(stream.data).decode("utf-8"),
                },
            }
        )
    return converted
138+
139+
93140
def _convert_messages_to_anthropic_format(
94141
messages: List[ChatMessage],
95142
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
@@ -119,10 +166,17 @@ def _convert_messages_to_anthropic_format(
119166

120167
anthropic_msg: Dict[str, Any] = {"role": message._role.value, "content": []}
121168

122-
if message.texts and message.texts[0]:
123-
anthropic_msg["content"].append({"type": "text", "text": message.texts[0]})
169+
if message.texts:
170+
for item in message.texts:
171+
anthropic_msg["content"].append({"type": "text", "text": item})
172+
if message.media:
173+
anthropic_msg["content"] += _convert_media_to_anthropic_format(
174+
message.media
175+
)
124176
if message.tool_calls:
125-
anthropic_msg["content"] += _convert_tool_calls_to_anthropic_format(message.tool_calls)
177+
anthropic_msg["content"] += _convert_tool_calls_to_anthropic_format(
178+
message.tool_calls
179+
)
126180

127181
if message.tool_call_results:
128182
results = message.tool_call_results.copy()
@@ -136,7 +190,8 @@ def _convert_messages_to_anthropic_format(
136190

137191
if not anthropic_msg["content"]:
138192
raise ValueError(
139-
"A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`."
193+
"A `ChatMessage` must contain at least one `TextContent`, `MediaContent`, "
194+
"`ToolCall`, or `ToolCallResult`."
140195
)
141196

142197
anthropic_non_system_messages.append(anthropic_msg)
@@ -250,7 +305,9 @@ def to_dict(self) -> Dict[str, Any]:
250305
The serialized component as a dictionary.
251306
"""
252307
serialized = super(AnthropicChatGenerator, self).to_dict()
253-
serialized["init_parameters"]["tools"] = [tool.to_dict() for tool in self.tools] if self.tools else None
308+
serialized["init_parameters"]["tools"] = (
309+
[tool.to_dict() for tool in self.tools] if self.tools else None
310+
)
254311
return serialized
255312

256313
@classmethod
@@ -267,11 +324,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "AnthropicChatGenerator":
267324
init_params = data.get("init_parameters", {})
268325
serialized_callback_handler = init_params.get("streaming_callback")
269326
if serialized_callback_handler:
270-
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
327+
data["init_parameters"]["streaming_callback"] = deserialize_callable(
328+
serialized_callback_handler
329+
)
271330

272331
return default_from_dict(cls, data)
273332

274-
def _convert_chat_completion_to_chat_message(self, anthropic_response: Any) -> ChatMessage:
333+
def _convert_chat_completion_to_chat_message(
334+
self, anthropic_response: Any
335+
) -> ChatMessage:
275336
"""
276337
Converts the response from the Anthropic API to a ChatMessage.
277338
"""
@@ -343,15 +404,22 @@ def _convert_streaming_chunks_to_chat_message(
343404
full_content += delta.get("text", "")
344405
elif delta.get("type") == "input_json_delta" and current_tool_call:
345406
current_tool_call["arguments"] += delta.get("partial_json", "")
346-
elif chunk_type == "message_delta": # noqa: SIM102 (prefer nested if statement here for readability)
347-
if chunk.meta.get("delta", {}).get("stop_reason") == "tool_use" and current_tool_call:
407+
elif (
408+
chunk_type == "message_delta"
409+
): # noqa: SIM102 (prefer nested if statement here for readability)
410+
if (
411+
chunk.meta.get("delta", {}).get("stop_reason") == "tool_use"
412+
and current_tool_call
413+
):
348414
try:
349415
# arguments is a string, convert to json
350416
tool_calls.append(
351417
ToolCall(
352418
id=current_tool_call.get("id"),
353419
tool_name=str(current_tool_call.get("name")),
354-
arguments=json.loads(current_tool_call.get("arguments", {})),
420+
arguments=json.loads(
421+
current_tool_call.get("arguments", {})
422+
),
355423
)
356424
)
357425
except json.JSONDecodeError:
@@ -370,7 +438,9 @@ def _convert_streaming_chunks_to_chat_message(
370438
{
371439
"model": model,
372440
"index": 0,
373-
"finish_reason": last_chunk_meta.get("delta", {}).get("stop_reason", None),
441+
"finish_reason": last_chunk_meta.get("delta", {}).get(
442+
"stop_reason", None
443+
),
374444
"usage": last_chunk_meta.get("usage", {}),
375445
}
376446
)
@@ -405,12 +475,16 @@ def run(
405475
disallowed_params,
406476
self.ALLOWED_PARAMS,
407477
)
408-
generation_kwargs = {k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS}
478+
generation_kwargs = {
479+
k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS
480+
}
409481
tools = tools or self.tools
410482
if tools:
411483
_check_duplicate_tool_names(tools)
412484

413-
system_messages, non_system_messages = _convert_messages_to_anthropic_format(messages)
485+
system_messages, non_system_messages = _convert_messages_to_anthropic_format(
486+
messages
487+
)
414488
anthropic_tools = (
415489
[
416490
{
@@ -447,12 +521,16 @@ def run(
447521
"content_block_delta",
448522
"message_delta",
449523
]:
450-
streaming_chunk = self._convert_anthropic_chunk_to_streaming_chunk(chunk)
524+
streaming_chunk = self._convert_anthropic_chunk_to_streaming_chunk(
525+
chunk
526+
)
451527
chunks.append(streaming_chunk)
452528
if streaming_callback:
453529
streaming_callback(streaming_chunk)
454530

455531
completion = self._convert_streaming_chunks_to_chat_message(chunks, model)
456532
return {"replies": [completion]}
457533
else:
458-
return {"replies": [self._convert_chat_completion_to_chat_message(response)]}
534+
return {
535+
"replies": [self._convert_chat_completion_to_chat_message(response)]
536+
}

haystack_experimental/components/generators/chat/openai.py

+77-41
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import json
66
import os
7+
from base64 import b64encode
78
from typing import Any, Dict, List, Optional, Union
89

910
from haystack import component, default_from_dict, default_to_dict, logging
@@ -19,7 +20,14 @@
1920
from openai.types.chat.chat_completion import Choice
2021
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
2122

22-
from haystack_experimental.dataclasses import ChatMessage, Tool, ToolCall
23+
from haystack_experimental.dataclasses import (
24+
ChatMessage,
25+
Tool,
26+
ToolCall,
27+
TextContent,
28+
ChatRole,
29+
MediaContent,
30+
)
2331
from haystack_experimental.dataclasses.streaming_chunk import (
2432
AsyncStreamingCallbackT,
2533
StreamingCallbackT,
@@ -34,53 +42,81 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, Any]:
3442
"""
3543
Convert a message to the format expected by OpenAI's Chat API.
3644
"""
37-
text_contents = message.texts
38-
tool_calls = message.tool_calls
39-
tool_call_results = message.tool_call_results
40-
41-
if not text_contents and not tool_calls and not tool_call_results:
42-
raise ValueError(
43-
"A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`."
44-
)
45-
elif len(text_contents) + len(tool_call_results) > 1:
45+
openai_msg: Dict[str, Any] = {"role": message.role.value}
46+
if len(message) == 0:
4647
raise ValueError(
47-
"A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`."
48+
"ChatMessage must contain at least one `TextContent`, "
49+
"`MediaContent`, `ToolCall`, or `ToolCallResult`."
4850
)
49-
50-
openai_msg: Dict[str, Any] = {"role": message._role.value}
51-
52-
if tool_call_results:
53-
result = tool_call_results[0]
54-
if result.origin.id is None:
51+
if len(message) == 1 and isinstance(message.content[0], TextContent):
52+
openai_msg["content"] = message.content[0].text
53+
elif message.tool_call_result:
54+
# Tool call results should only be included for ChatRole.TOOL messages
55+
# and should not include any other content
56+
if message.role != ChatRole.TOOL:
57+
raise ValueError(
58+
"Tool call results should only be included for tool messages."
59+
)
60+
if len(message) > 1:
61+
raise ValueError(
62+
"Tool call results should not be included with other content."
63+
)
64+
if message.tool_call_result.origin.id is None:
5565
raise ValueError(
5666
"`ToolCall` must have a non-null `id` attribute to be used with OpenAI."
5767
)
58-
openai_msg["content"] = result.result
59-
openai_msg["tool_call_id"] = result.origin.id
60-
# OpenAI does not provide a way to communicate errors in tool invocations, so we ignore the error field
61-
return openai_msg
62-
63-
if text_contents:
64-
openai_msg["content"] = text_contents[0]
65-
if tool_calls:
66-
openai_tool_calls = []
67-
for tc in tool_calls:
68-
if tc.id is None:
68+
openai_msg["content"] = message.tool_call_result.result
69+
openai_msg["tool_call_id"] = message.tool_call_result.origin.id
70+
else:
71+
openai_msg["content"] = []
72+
for item in message.content:
73+
if isinstance(item, TextContent):
74+
openai_msg["content"].append({"type": "text", "text": item.text})
75+
elif isinstance(item, MediaContent):
76+
match item.media.type:
77+
case "image":
78+
base64_data = b64encode(item.media.data).decode("utf-8")
79+
url = f"data:{item.media.mime_type};base64,{base64_data}"
80+
openai_msg["content"].append(
81+
{
82+
"type": "image_url",
83+
"image_url": {
84+
"url": url,
85+
"detail": item.media.meta.get("detail", "auto"),
86+
},
87+
}
88+
)
89+
case _:
90+
raise ValueError(
91+
f"Unsupported media type '{item.media.mime_type}' for OpenAI completions."
92+
)
93+
elif isinstance(item, ToolCall):
94+
if message.role != ChatRole.ASSISTANT:
95+
raise ValueError(
96+
"Tool calls should only be included for assistant messages."
97+
)
98+
if item.id is None:
99+
raise ValueError(
100+
"`ToolCall` must have a non-null `id` attribute to be used with OpenAI."
101+
)
102+
openai_msg.setdefault("tool_calls", []).append(
103+
{
104+
"id": item.id,
105+
"type": "function",
106+
"function": {
107+
"name": item.tool_name,
108+
"arguments": json.dumps(item.arguments, ensure_ascii=False),
109+
},
110+
}
111+
)
112+
else:
69113
raise ValueError(
70-
"`ToolCall` must have a non-null `id` attribute to be used with OpenAI."
114+
f"Unsupported content type '{type(item).__name__}' for OpenAI completions."
71115
)
72-
openai_tool_calls.append(
73-
{
74-
"id": tc.id,
75-
"type": "function",
76-
# We disable ensure_ascii so special chars like emojis are not converted
77-
"function": {
78-
"name": tc.tool_name,
79-
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
80-
},
81-
}
82-
)
83-
openai_msg["tool_calls"] = openai_tool_calls
116+
117+
if message.name:
118+
openai_msg["name"] = message.name
119+
84120
return openai_msg
85121

86122

0 commit comments

Comments
 (0)