Skip to content

Commit e596a89

Browse files
committed
feat!(ContentToolResult): add a new model_format parameter
1 parent 96b8fec commit e596a89

File tree

8 files changed

+192
-22
lines changed

8 files changed

+192
-22
lines changed

chatlas/_anthropic.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -488,12 +488,15 @@ def _as_content_block(content: Content) -> "ContentBlockParam":
488488
"input": content.arguments,
489489
}
490490
elif isinstance(content, ContentToolResult):
491-
return {
491+
res: ToolResultBlockParam = {
492492
"type": "tool_result",
493493
"tool_use_id": content.id,
494-
"content": content.get_final_value(),
495494
"is_error": content.error is not None,
496495
}
496+
# Anthropic supports non-text contents like ImageBlockParam
497+
res["content"] = content.get_formatted_value() # type: ignore
498+
return res
499+
497500
raise ValueError(f"Unknown content type: {type(content)}")
498501

499502
@staticmethod

chatlas/_content.py

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,15 @@ class ContentToolResult(Content):
226226
----------
227227
value
228228
The value returned by the tool/function (to be sent to the model).
229+
model_format
230+
The format used for sending the value to the model. The default,
231+
`"auto"`, first attempts to format the value as a JSON string. If that
232+
fails, it gets converted to a string via `str()`. To force
233+
`json.dumps()` or `str()`, set to `"json"` or `"str"`. Finally,
234+
`"as_is"` is useful for doing your own formatting and/or passing a
235+
non-string value (e.g., a list or dict) straight to the model.
236+
Non-string values are useful for tools that return images or other
237+
'known' non-text content types.
229238
error
230239
An exception that occurred during the tool request. If this is set, the
231240
error message sent to the model and the value is ignored.
@@ -239,6 +248,7 @@ class ContentToolResult(Content):
239248

240249
# public
241250
value: Any
251+
model_format: Literal["auto", "json", "str", "as_is"] = "auto"
242252
error: Optional[Exception] = None
243253
extra: Any = None
244254

@@ -266,22 +276,11 @@ def arguments(self):
266276
)
267277
return self.request.arguments
268278

269-
def _get_value(self, pretty: bool = False) -> str:
270-
if self.error:
271-
return f"Tool call failed with error: '{self.error}'"
272-
if not pretty:
273-
return str(self.value)
274-
try:
275-
json_val = json.loads(self.value) # type: ignore
276-
return pformat(json_val, indent=2, sort_dicts=False)
277-
except: # noqa
278-
return str(self.value)
279-
280279
# Primarily used for `echo="all"`...
281280
def __str__(self):
282281
prefix = "✅ tool result" if not self.error else "❌ tool error"
283282
comment = f"# {prefix} ({self.id})"
284-
value = self._get_value(pretty=True)
283+
value = self._get_display_value()
285284
return f"""```python\n{comment}\n{value}\n```"""
286285

287286
# ... and for displaying in the notebook
@@ -295,9 +294,52 @@ def __repr__(self, indent: int = 0):
295294
res += f" error='{self.error}'"
296295
return res + ">"
297296

298-
# The actual value to send to the model
299-
def get_final_value(self) -> str:
300-
return self._get_value()
297+
# Format the value for display purposes
298+
def _get_display_value(self) -> object:
299+
if self.error:
300+
return f"Tool call failed with error: '{self.error}'"
301+
302+
val = self.value
303+
304+
# If value is already a dict or list, format it directly
305+
if isinstance(val, (dict, list)):
306+
return pformat(val, indent=2, sort_dicts=False)
307+
308+
# For string values, try to parse as JSON
309+
if isinstance(val, str):
310+
try:
311+
json_val = json.loads(val)
312+
return pformat(json_val, indent=2, sort_dicts=False)
313+
except json.JSONDecodeError:
314+
# Not valid JSON, return as string
315+
return val
316+
317+
return val
318+
319+
def get_model_value(self) -> object:
320+
"Get the actual value sent to the model."
321+
322+
if self.error:
323+
return f"Tool call failed with error: '{self.error}'"
324+
325+
val, mode = (self.value, self.model_format)
326+
327+
if isinstance(val, str):
328+
return val
329+
330+
if mode == "auto":
331+
try:
332+
return json.dumps(val)
333+
except Exception:
334+
return str(val)
335+
elif mode == "json":
336+
return json.dumps(val)
337+
elif mode == "str":
338+
return str(val)
339+
elif mode == "as_is":
340+
return val
341+
else:
342+
raise ValueError(f"Unknown format mode: {mode}")
301343

302344
def tagify(self) -> "TagChild":
303345
"""
@@ -317,7 +359,7 @@ def tagify(self) -> "TagChild":
317359
header = f"❌ Failed to call tool <code>{self.name}</code>"
318360

319361
args = self._arguments_str()
320-
content = self._get_value(pretty=True)
362+
content = self._get_display_value()
321363

322364
return HTML(
323365
textwrap.dedent(f"""

chatlas/_google.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ def _as_part_type(self, content: Content) -> "Part":
432432
if content.error:
433433
resp = {"error": content.error}
434434
else:
435-
resp = {"result": str(content.value)}
435+
resp = {"result": content.get_model_value()}
436436
return Part(
437437
# TODO: seems function response parts might need role='tool'???
438438
# https://github.com/googleapis/python-genai/blame/c8cfef85c/README.md#L344

chatlas/_openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,8 +499,8 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
499499
elif isinstance(x, ContentToolResult):
500500
tool_results.append(
501501
ChatCompletionToolMessageParam(
502-
# TODO: a tool could return an image!?!
503-
content=x.get_final_value(),
502+
# Currently, OpenAI only allows for text content in tool results
503+
content=cast(str, x.get_formatted_value()),
504504
tool_call_id=x.id,
505505
role="tool",
506506
)

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,8 @@ def assert_pdf_local(chat_fun: ChatFun):
256256
wait=wait_exponential(min=1, max=60),
257257
reraise=True,
258258
)
259+
260+
261+
@pytest.fixture
262+
def test_images_dir():
263+
return Path(__file__).parent / "images"

tests/images/dice.png

219 KB
Loading

tests/test_provider_anthropic.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import base64
2+
13
import pytest
2-
from chatlas import ChatAnthropic
4+
from chatlas import ChatAnthropic, ContentToolResult
35

46
from .conftest import (
57
assert_data_extraction,
@@ -96,3 +98,33 @@ def test_anthropic_empty_response():
9698
chat.chat("Respond with only two blank lines")
9799
resp = chat.chat("What's 1+1? Just give me the number")
98100
assert "2" == str(resp).strip()
101+
102+
103+
def test_anthropic_image_tool(test_images_dir):
104+
def get_picture():
105+
"Returns an image"
106+
# Local copy of https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png
107+
with open(test_images_dir / "dice.png", "rb") as image:
108+
bytez = image.read()
109+
res = [
110+
{
111+
"type": "image",
112+
"source": {
113+
"type": "base64",
114+
"media_type": "image/png",
115+
"data": base64.b64encode(bytez).decode("utf-8"),
116+
},
117+
}
118+
]
119+
return ContentToolResult(value=res, model_format="as_is")
120+
121+
chat = ChatAnthropic()
122+
chat.register_tool(get_picture)
123+
124+
res = chat.chat(
125+
"You have a tool called 'get_picture' available to you. "
126+
"When called, it returns an image. "
127+
"Tell me what you see in the image."
128+
)
129+
130+
assert "dice" in res.get_content()

tests/test_provider_snowflake.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import pytest
2+
3+
from chatlas import ChatSnowflake
4+
5+
from .conftest import assert_data_extraction, assert_turns_existing, assert_turns_system
6+
7+
8+
def test_openai_simple_request():
9+
chat = ChatSnowflake(
10+
connection_name="posit",
11+
system_prompt="Be as terse as possible; no punctuation",
12+
)
13+
chat.chat("What is 1 + 1?")
14+
turn = chat.get_last_turn()
15+
assert turn is not None
16+
17+
# No token / finish_reason info available
18+
# assert turn.tokens is not None
19+
# assert len(turn.tokens) == 2
20+
# assert turn.tokens[0] == 27
21+
# # Not testing turn.tokens[1] because it's not deterministic. Typically 1 or 2.
22+
# assert turn.finish_reason == "stop"
23+
24+
25+
@pytest.mark.asyncio
26+
async def test_openai_simple_streaming_request():
27+
chat = ChatSnowflake(
28+
connection_name="posit",
29+
system_prompt="Be as terse as possible; no punctuation",
30+
)
31+
res = []
32+
async for x in await chat.stream_async("What is 1 + 1?"):
33+
res.append(x)
34+
assert "2" in "".join(res)
35+
turn = chat.get_last_turn()
36+
37+
# No token / finish_reason info available
38+
# assert turn is not None
39+
# assert turn.finish_reason == "stop"
40+
41+
42+
def test_openai_respects_turns_interface():
43+
def chat_fun(**kwargs):
44+
return ChatSnowflake(connection_name="posit", **kwargs)
45+
46+
assert_turns_system(chat_fun)
47+
assert_turns_existing(chat_fun)
48+
49+
50+
#
51+
# def test_openai_tool_variations():
52+
# def chat_fun(**kwargs):
53+
# return ChatSnowflake(connection_name="posit", **kwargs)
54+
#
55+
# assert_tools_simple(chat_fun)
56+
# assert_tools_simple_stream_content(chat_fun)
57+
# assert_tools_parallel(chat_fun)
58+
# assert_tools_sequential(chat_fun, total_calls=6)
59+
#
60+
#
61+
# @pytest.mark.asyncio
62+
# async def test_openai_tool_variations_async():
63+
# def chat_fun(**kwargs):
64+
# return ChatSnowflake(connection_name="posit", **kwargs)
65+
#
66+
# await assert_tools_async(chat_fun)
67+
68+
69+
def test_data_extraction():
70+
def chat_fun():
71+
return ChatSnowflake(connection_name="posit")
72+
73+
assert_data_extraction(chat_fun)
74+
75+
76+
# def test_openai_images():
77+
# def chat_fun(**kwargs):
78+
# return ChatSnowflake(connection_name="posit", **kwargs)
79+
#
80+
# assert_images_inline(chat_fun)
81+
# assert_images_remote(chat_fun)
82+
#
83+
#
84+
# def test_openai_pdf():
85+
# def chat_fun(**kwargs):
86+
# return ChatSnowflake(connection_name="posit", **kwargs)
87+
#
88+
# assert_pdf_local(chat_fun)

0 commit comments

Comments
 (0)