4
4
5
5
import json
6
6
import logging
7
+ from base64 import b64encode
7
8
from typing import Any , Callable , Dict , List , Optional , Tuple , Type
8
9
9
10
from haystack import component , default_from_dict
10
11
from haystack .dataclasses import StreamingChunk
11
12
from haystack .lazy_imports import LazyImport
12
13
from haystack .utils import Secret , deserialize_callable , deserialize_secrets_inplace
13
14
14
- from haystack_experimental .dataclasses import ChatMessage , ToolCall
15
+ from haystack_experimental .dataclasses import ChatMessage , ToolCall , ByteStream
15
16
from haystack_experimental .dataclasses .chat_message import ChatRole , ToolCallResult
16
17
from haystack_experimental .dataclasses .tool import Tool , deserialize_tools_inplace
17
18
38
39
# - AnthropicChatGenerator fails with ImportError at init (due to anthropic_integration_import.check()).
39
40
40
41
if anthropic_integration_import .is_successful ():
41
- chatgenerator_base_class : Type [AnthropicChatGeneratorBase ] = AnthropicChatGeneratorBase
42
+ chatgenerator_base_class : Type [AnthropicChatGeneratorBase ] = (
43
+ AnthropicChatGeneratorBase
44
+ )
42
45
else :
43
46
chatgenerator_base_class : Type [object ] = object # type: ignore[no-redef]
44
47
@@ -57,7 +60,9 @@ def _update_anthropic_message_with_tool_call_results(
57
60
58
61
for tool_call_result in tool_call_results :
59
62
if tool_call_result .origin .id is None :
60
- raise ValueError ("`ToolCall` must have a non-null `id` attribute to be used with Anthropic." )
63
+ raise ValueError (
64
+ "`ToolCall` must have a non-null `id` attribute to be used with Anthropic."
65
+ )
61
66
anthropic_msg ["content" ].append (
62
67
{
63
68
"type" : "tool_result" ,
@@ -68,7 +73,9 @@ def _update_anthropic_message_with_tool_call_results(
68
73
)
69
74
70
75
71
- def _convert_tool_calls_to_anthropic_format (tool_calls : List [ToolCall ]) -> List [Dict [str , Any ]]:
76
+ def _convert_tool_calls_to_anthropic_format (
77
+ tool_calls : List [ToolCall ],
78
+ ) -> List [Dict [str , Any ]]:
72
79
"""
73
80
Convert a list of tool calls to the format expected by Anthropic Chat API.
74
81
@@ -78,7 +85,9 @@ def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[
78
85
anthropic_tool_calls = []
79
86
for tc in tool_calls :
80
87
if tc .id is None :
81
- raise ValueError ("`ToolCall` must have a non-null `id` attribute to be used with Anthropic." )
88
+ raise ValueError (
89
+ "`ToolCall` must have a non-null `id` attribute to be used with Anthropic."
90
+ )
82
91
anthropic_tool_calls .append (
83
92
{
84
93
"type" : "tool_use" ,
@@ -90,6 +99,44 @@ def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[
90
99
return anthropic_tool_calls
91
100
92
101
102
+ def _convert_media_to_anthropic_format (media : List [ByteStream ]) -> List [Dict [str , Any ]]:
103
+ """
104
+ Convert a list of media to the format expected by Anthropic Chat API.
105
+
106
+ :param media: The list of ByteStreams to convert.
107
+ :return: A list of dictionaries in the format expected by Anthropic API.
108
+ """
109
+ anthropic_media = []
110
+ for item in media :
111
+ if item .type == "image" :
112
+ anthropic_media .append (
113
+ {
114
+ "type" : "image" ,
115
+ "source" : {
116
+ "type" : "base64" ,
117
+ "media_type" : item .mime_type ,
118
+ "data" : b64encode (item .data ).decode ("utf-8" ),
119
+ },
120
+ }
121
+ )
122
+ elif item .type == "application" and item .subtype == "pdf" :
123
+ anthropic_media .append (
124
+ {
125
+ "type" : "document" ,
126
+ "source" : {
127
+ "type" : "base64" ,
128
+ "media_type" : item .mime_type ,
129
+ "data" : b64encode (item .data ).decode ("utf-8" ),
130
+ },
131
+ }
132
+ )
133
+ else :
134
+ raise ValueError (
135
+ f"Unsupported media type '{ item .mime_type } ' for Anthropic completions."
136
+ )
137
+ return anthropic_media
138
+
139
+
93
140
def _convert_messages_to_anthropic_format (
94
141
messages : List [ChatMessage ],
95
142
) -> Tuple [List [Dict [str , Any ]], List [Dict [str , Any ]]]:
@@ -119,10 +166,17 @@ def _convert_messages_to_anthropic_format(
119
166
120
167
anthropic_msg : Dict [str , Any ] = {"role" : message ._role .value , "content" : []}
121
168
122
- if message .texts and message .texts [0 ]:
123
- anthropic_msg ["content" ].append ({"type" : "text" , "text" : message .texts [0 ]})
169
+ if message .texts :
170
+ for item in message .texts :
171
+ anthropic_msg ["content" ].append ({"type" : "text" , "text" : item })
172
+ if message .media :
173
+ anthropic_msg ["content" ] += _convert_media_to_anthropic_format (
174
+ message .media
175
+ )
124
176
if message .tool_calls :
125
- anthropic_msg ["content" ] += _convert_tool_calls_to_anthropic_format (message .tool_calls )
177
+ anthropic_msg ["content" ] += _convert_tool_calls_to_anthropic_format (
178
+ message .tool_calls
179
+ )
126
180
127
181
if message .tool_call_results :
128
182
results = message .tool_call_results .copy ()
@@ -136,7 +190,8 @@ def _convert_messages_to_anthropic_format(
136
190
137
191
if not anthropic_msg ["content" ]:
138
192
raise ValueError (
139
- "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`."
193
+ "A `ChatMessage` must contain at least one `TextContent`, `MediaContent`, "
194
+ "`ToolCall`, or `ToolCallResult`."
140
195
)
141
196
142
197
anthropic_non_system_messages .append (anthropic_msg )
@@ -250,7 +305,9 @@ def to_dict(self) -> Dict[str, Any]:
250
305
The serialized component as a dictionary.
251
306
"""
252
307
serialized = super (AnthropicChatGenerator , self ).to_dict ()
253
- serialized ["init_parameters" ]["tools" ] = [tool .to_dict () for tool in self .tools ] if self .tools else None
308
+ serialized ["init_parameters" ]["tools" ] = (
309
+ [tool .to_dict () for tool in self .tools ] if self .tools else None
310
+ )
254
311
return serialized
255
312
256
313
@classmethod
@@ -267,11 +324,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "AnthropicChatGenerator":
267
324
init_params = data .get ("init_parameters" , {})
268
325
serialized_callback_handler = init_params .get ("streaming_callback" )
269
326
if serialized_callback_handler :
270
- data ["init_parameters" ]["streaming_callback" ] = deserialize_callable (serialized_callback_handler )
327
+ data ["init_parameters" ]["streaming_callback" ] = deserialize_callable (
328
+ serialized_callback_handler
329
+ )
271
330
272
331
return default_from_dict (cls , data )
273
332
274
- def _convert_chat_completion_to_chat_message (self , anthropic_response : Any ) -> ChatMessage :
333
+ def _convert_chat_completion_to_chat_message (
334
+ self , anthropic_response : Any
335
+ ) -> ChatMessage :
275
336
"""
276
337
Converts the response from the Anthropic API to a ChatMessage.
277
338
"""
@@ -343,15 +404,22 @@ def _convert_streaming_chunks_to_chat_message(
343
404
full_content += delta .get ("text" , "" )
344
405
elif delta .get ("type" ) == "input_json_delta" and current_tool_call :
345
406
current_tool_call ["arguments" ] += delta .get ("partial_json" , "" )
346
- elif chunk_type == "message_delta" : # noqa: SIM102 (prefer nested if statement here for readability)
347
- if chunk .meta .get ("delta" , {}).get ("stop_reason" ) == "tool_use" and current_tool_call :
407
+ elif (
408
+ chunk_type == "message_delta"
409
+ ): # noqa: SIM102 (prefer nested if statement here for readability)
410
+ if (
411
+ chunk .meta .get ("delta" , {}).get ("stop_reason" ) == "tool_use"
412
+ and current_tool_call
413
+ ):
348
414
try :
349
415
# arguments is a string, convert to json
350
416
tool_calls .append (
351
417
ToolCall (
352
418
id = current_tool_call .get ("id" ),
353
419
tool_name = str (current_tool_call .get ("name" )),
354
- arguments = json .loads (current_tool_call .get ("arguments" , {})),
420
+ arguments = json .loads (
421
+ current_tool_call .get ("arguments" , {})
422
+ ),
355
423
)
356
424
)
357
425
except json .JSONDecodeError :
@@ -370,7 +438,9 @@ def _convert_streaming_chunks_to_chat_message(
370
438
{
371
439
"model" : model ,
372
440
"index" : 0 ,
373
- "finish_reason" : last_chunk_meta .get ("delta" , {}).get ("stop_reason" , None ),
441
+ "finish_reason" : last_chunk_meta .get ("delta" , {}).get (
442
+ "stop_reason" , None
443
+ ),
374
444
"usage" : last_chunk_meta .get ("usage" , {}),
375
445
}
376
446
)
@@ -405,12 +475,16 @@ def run(
405
475
disallowed_params ,
406
476
self .ALLOWED_PARAMS ,
407
477
)
408
- generation_kwargs = {k : v for k , v in generation_kwargs .items () if k in self .ALLOWED_PARAMS }
478
+ generation_kwargs = {
479
+ k : v for k , v in generation_kwargs .items () if k in self .ALLOWED_PARAMS
480
+ }
409
481
tools = tools or self .tools
410
482
if tools :
411
483
_check_duplicate_tool_names (tools )
412
484
413
- system_messages , non_system_messages = _convert_messages_to_anthropic_format (messages )
485
+ system_messages , non_system_messages = _convert_messages_to_anthropic_format (
486
+ messages
487
+ )
414
488
anthropic_tools = (
415
489
[
416
490
{
@@ -447,12 +521,16 @@ def run(
447
521
"content_block_delta" ,
448
522
"message_delta" ,
449
523
]:
450
- streaming_chunk = self ._convert_anthropic_chunk_to_streaming_chunk (chunk )
524
+ streaming_chunk = self ._convert_anthropic_chunk_to_streaming_chunk (
525
+ chunk
526
+ )
451
527
chunks .append (streaming_chunk )
452
528
if streaming_callback :
453
529
streaming_callback (streaming_chunk )
454
530
455
531
completion = self ._convert_streaming_chunks_to_chat_message (chunks , model )
456
532
return {"replies" : [completion ]}
457
533
else :
458
- return {"replies" : [self ._convert_chat_completion_to_chat_message (response )]}
534
+ return {
535
+ "replies" : [self ._convert_chat_completion_to_chat_message (response )]
536
+ }
0 commit comments