Skip to content

Commit b33422b

Browse files
authored
feat!(ContentToolResult): add a new model_format parameter (#87)
* feat!(ContentToolResult): add a new model_format parameter * Update changelog * Use orjson over json * More sophisticated json dumping approach * Improve docstring of ContentToolResult * Address feedback * Always dump to a string * Small tweak to changelog
1 parent 96b8fec commit b33422b

File tree

11 files changed

+160
-48
lines changed

11 files changed

+160
-48
lines changed

CHANGELOG.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212
### New features
1313

1414
* Added `ChatDatabricks()`, for chatting with Databrick's [foundation models](https://docs.databricks.com/aws/en/machine-learning/model-serving/score-foundation-models). (#82)
15-
* `.stream()` and `.stream_async()` gain a `content` argument. Set this to `"all"` to include `ContentToolRequest` and `ContentToolResponse` instances in the stream. (#75)
16-
* `ContentToolRequest` and `ContentToolResponse` are now exported to `chatlas` namespace. (#75)
17-
* `ContentToolRequest` and `ContentToolResponse` now have `.tagify()` methods, making it so they can render automatically in a Shiny chatbot. (#75)
18-
* `ContentToolResult` instances can be returned from tools. This allows for custom rendering of the tool result. (#75)
15+
* `.stream()` and `.stream_async()` gain a `content` argument. Set this to `"all"` to include `ContentToolResult`/`ContentToolRequest` objects in the stream. (#75)
16+
* `ContentToolResult`/`ContentToolRequest` are now exported to `chatlas` namespace. (#75)
17+
* `ContentToolResult`/`ContentToolRequest` gain a `.tagify()` method so they render sensibly in a Shiny app. (#75)
18+
* A tool can now return a `ContentToolResult`. This is useful for:
19+
* Specifying the format used for sending the tool result to the chat model (`model_format`). (#87)
20+
* Custom rendering of the tool result (by overriding relevant methods in a subclass). (#75)
1921
* `Chat` gains a new `.current_display` property. When a `.chat()` or `.stream()` is currently active, this property returns an object with a `.echo()` method (to echo new content to the display). This is primarily useful for displaying custom content during a tool call. (#79)
2022

2123
### Improvements
@@ -25,11 +27,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2527
* `.extract_data()` is now supported.
2628
* `async` methods are now supported. (#81)
2729
* Fixed an issue with more than one session being active at once. (#83)
28-
* `ChatAnthropic()` no longer choke after receiving an output that consists only of whitespace. (#86)
30+
* `ChatAnthropic()` no longer chokes after receiving an output that consists only of whitespace. (#86)
31+
* `orjson` is now used for JSON loading and dumping. (#87)
2932

3033
### Changes
3134

3235
* The `echo` argument of the `.chat()` method defaults to a new value of `"output"`. As a result, tool requests and results are now echoed by default. To revert to the previous behavior, set `echo="text"`. (#78)
36+
* Tool results are now dumped to JSON by default before being sent to the model. To revert to the previous behavior, have the tool return a `ContentToolResult` with `model_format="str"`. (#87)
3337

3438
### Breaking changes
3539

chatlas/_anthropic.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from __future__ import annotations
22

33
import base64
4-
import json
54
import warnings
65
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast, overload
76

7+
import orjson
88
from pydantic import BaseModel
99

1010
from ._chat import Chat
@@ -366,8 +366,8 @@ def stream_merge_chunks(self, completion, chunk):
366366
this_content = completion.content[chunk.index]
367367
if this_content.type == "tool_use" and isinstance(this_content.input, str):
368368
try:
369-
this_content.input = json.loads(this_content.input or "{}")
370-
except json.JSONDecodeError as e:
369+
this_content.input = orjson.loads(this_content.input or "{}")
370+
except orjson.JSONDecodeError as e:
371371
raise ValueError(f"Invalid JSON input: {e}")
372372
elif chunk.type == "message_delta":
373373
completion.stop_reason = chunk.delta.stop_reason
@@ -488,12 +488,15 @@ def _as_content_block(content: Content) -> "ContentBlockParam":
488488
"input": content.arguments,
489489
}
490490
elif isinstance(content, ContentToolResult):
491-
return {
491+
res: ToolResultBlockParam = {
492492
"type": "tool_result",
493493
"tool_use_id": content.id,
494-
"content": content.get_final_value(),
495494
"is_error": content.error is not None,
496495
}
496+
# Anthropic supports non-text contents like ImageBlockParam
497+
res["content"] = content.get_model_value() # type: ignore
498+
return res
499+
497500
raise ValueError(f"Unknown content type: {type(content)}")
498501

499502
@staticmethod

chatlas/_auto.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

3-
import json
43
import os
54
from typing import Callable, Literal, Optional
65

6+
import orjson
7+
78
from ._anthropic import ChatAnthropic, ChatBedrockAnthropic
89
from ._chat import Chat
910
from ._databricks import ChatDatabricks
@@ -178,7 +179,7 @@ def ChatAuto(
178179

179180
env_kwargs = {}
180181
if env_kwargs_str := os.environ.get("CHATLAS_CHAT_ARGS"):
181-
env_kwargs = json.loads(env_kwargs_str)
182+
env_kwargs = orjson.loads(env_kwargs_str)
182183

183184
kwargs = {**kwargs, **env_kwargs, **base_args}
184185
kwargs = {k: v for k, v in kwargs.items() if v is not None}

chatlas/_content.py

Lines changed: 88 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from __future__ import annotations
22

3-
import json
43
import textwrap
54
from pprint import pformat
65
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
76

7+
import orjson
88
from pydantic import BaseModel, ConfigDict
99

1010
if TYPE_CHECKING:
@@ -218,27 +218,50 @@ class ContentToolResult(Content):
218218
"""
219219
The result of calling a tool/function
220220
221-
This content type isn't meant to be used directly. Instead, it's
222-
automatically generated by [](`~chatlas.Chat`) when a tool/function is
223-
called (in response to a [](`~chatlas.ContentToolRequest`)).
221+
A content type representing the result of a tool function call. When a model
222+
requests a tool function, [](`~chatlas.Chat`) will create, (optionally)
223+
echo, (optionally) yield, and store this content type in the chat history.
224+
225+
A tool function may also construct an instance of this class and return it.
226+
This is useful for a tool that wishes to customize how the result is handled
227+
(e.g., the format of the value sent to the model).
224228
225229
Parameters
226230
----------
227231
value
228-
The value returned by the tool/function (to be sent to the model).
232+
The return value of the tool/function.
233+
model_format
234+
The format used for sending the value to the model. The default,
235+
`"auto"`, first attempts to format the value as a JSON string. If that
236+
fails, it gets converted to a string via `str()`. To force
237+
`orjson.dumps()` or `str()`, set to `"json"` or `"str"`. Finally,
238+
`"as_is"` is useful for doing your own formatting and/or passing a
239+
non-string value (e.g., a list or dict) straight to the model.
240+
Non-string values are useful for tools that return images or other
241+
'known' non-text content types.
229242
error
230-
An exception that occurred during the tool request. If this is set, the
243+
An exception that occurred while invoking the tool. If this is set, the
231244
error message sent to the model and the value is ignored.
232245
extra
233246
Additional data associated with the tool result that isn't sent to the
234247
model.
235248
request
236249
Not intended to be used directly. It will be set when the
237250
:class:`~chatlas.Chat` invokes the tool.
251+
252+
Note
253+
----
254+
When `model_format` is `"json"` (or `"auto"`), and the value has a
255+
`.to_json()`/`.to_dict()` method, those methods are called to obtain the
256+
JSON representation of the value. This is convenient for classes, like
257+
`pandas.DataFrame`, that have a `.to_json()` method, but don't necessarily
258+
dump to JSON directly. If this happens to not be the desired behavior, set
259+
`model_format="as_is"` return the desired value as-is.
238260
"""
239261

240262
# public
241263
value: Any
264+
model_format: Literal["auto", "json", "str", "as_is"] = "auto"
242265
error: Optional[Exception] = None
243266
extra: Any = None
244267

@@ -266,22 +289,11 @@ def arguments(self):
266289
)
267290
return self.request.arguments
268291

269-
def _get_value(self, pretty: bool = False) -> str:
270-
if self.error:
271-
return f"Tool call failed with error: '{self.error}'"
272-
if not pretty:
273-
return str(self.value)
274-
try:
275-
json_val = json.loads(self.value) # type: ignore
276-
return pformat(json_val, indent=2, sort_dicts=False)
277-
except: # noqa
278-
return str(self.value)
279-
280292
# Primarily used for `echo="all"`...
281293
def __str__(self):
282294
prefix = "✅ tool result" if not self.error else "❌ tool error"
283295
comment = f"# {prefix} ({self.id})"
284-
value = self._get_value(pretty=True)
296+
value = self._get_display_value()
285297
return f"""```python\n{comment}\n{value}\n```"""
286298

287299
# ... and for displaying in the notebook
@@ -295,9 +307,62 @@ def __repr__(self, indent: int = 0):
295307
res += f" error='{self.error}'"
296308
return res + ">"
297309

298-
# The actual value to send to the model
299-
def get_final_value(self) -> str:
300-
return self._get_value()
310+
# Format the value for display purposes
311+
def _get_display_value(self) -> object:
312+
if self.error:
313+
return f"Tool call failed with error: '{self.error}'"
314+
315+
val = self.value
316+
317+
# If value is already a dict or list, format it directly
318+
if isinstance(val, (dict, list)):
319+
return pformat(val, indent=2, sort_dicts=False)
320+
321+
# For string values, try to parse as JSON
322+
if isinstance(val, str):
323+
try:
324+
json_val = orjson.loads(val)
325+
return pformat(json_val, indent=2, sort_dicts=False)
326+
except orjson.JSONDecodeError:
327+
# Not valid JSON, return as string
328+
return val
329+
330+
return val
331+
332+
def get_model_value(self) -> object:
333+
"Get the actual value sent to the model."
334+
335+
if self.error:
336+
return f"Tool call failed with error: '{self.error}'"
337+
338+
val, mode = (self.value, self.model_format)
339+
340+
if isinstance(val, str):
341+
return val
342+
343+
if mode == "auto":
344+
try:
345+
return self._to_json(val)
346+
except Exception:
347+
return str(val)
348+
elif mode == "json":
349+
return self._to_json(val)
350+
elif mode == "str":
351+
return str(val)
352+
elif mode == "as_is":
353+
return val
354+
else:
355+
raise ValueError(f"Unknown format mode: {mode}")
356+
357+
@staticmethod
358+
def _to_json(value: Any) -> object:
359+
if hasattr(value, "to_json") and callable(value.to_json):
360+
return value.to_json()
361+
362+
if hasattr(value, "to_dict") and callable(value.to_dict):
363+
value = value.to_dict()
364+
365+
return orjson.dumps(value).decode("utf-8")
301366

302367
def tagify(self) -> "TagChild":
303368
"""
@@ -317,7 +382,7 @@ def tagify(self) -> "TagChild":
317382
header = f"❌ Failed to call tool <code>{self.name}</code>"
318383

319384
args = self._arguments_str()
320-
content = self._get_value(pretty=True)
385+
content = self._get_display_value()
321386

322387
return HTML(
323388
textwrap.dedent(f"""
@@ -355,7 +420,7 @@ class ContentJson(Content):
355420
content_type: ContentTypeEnum = "json"
356421

357422
def __str__(self):
358-
return json.dumps(self.value, indent=2)
423+
return orjson.dumps(self.value, option=orjson.OPT_INDENT_2).decode("utf-8")
359424

360425
def _repr_markdown_(self):
361426
return f"""```json\n{self.__str__()}\n```"""

chatlas/_google.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

33
import base64
4-
import json
54
from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload
65

6+
import orjson
77
from pydantic import BaseModel
88

99
from ._chat import Chat
@@ -432,7 +432,7 @@ def _as_part_type(self, content: Content) -> "Part":
432432
if content.error:
433433
resp = {"error": content.error}
434434
else:
435-
resp = {"result": str(content.value)}
435+
resp = {"result": content.get_model_value()}
436436
return Part(
437437
# TODO: seems function response parts might need role='tool'???
438438
# https://github.com/googleapis/python-genai/blame/c8cfef85c/README.md#L344
@@ -470,7 +470,7 @@ def _as_turn(
470470
text = part.get("text")
471471
if text:
472472
if has_data_model:
473-
contents.append(ContentJson(value=json.loads(text)))
473+
contents.append(ContentJson(value=orjson.loads(text)))
474474
else:
475475
contents.append(ContentText(text=text))
476476
function_call = part.get("function_call")

chatlas/_ollama.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

3-
import json
43
import re
54
import urllib.request
65
from typing import TYPE_CHECKING, Optional
76

7+
import orjson
8+
89
from ._chat import Chat
910
from ._openai import ChatOpenAI
1011
from ._turn import Turn
@@ -121,7 +122,7 @@ def ChatOllama(
121122

122123
def ollama_models(base_url: str) -> list[str]:
123124
res = urllib.request.urlopen(url=f"{base_url}/api/tags")
124-
data = json.loads(res.read())
125+
data = orjson.loads(res.read())
125126
return [re.sub(":latest$", "", x["name"]) for x in data["models"]]
126127

127128

chatlas/_openai.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

33
import base64
4-
import json
54
from typing import TYPE_CHECKING, Any, Literal, Optional, cast, overload
65

6+
import orjson
77
from pydantic import BaseModel
88

99
from ._chat import Chat
@@ -433,7 +433,7 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
433433
"id": x.id,
434434
"function": {
435435
"name": x.name,
436-
"arguments": json.dumps(x.arguments),
436+
"arguments": orjson.dumps(x.arguments).decode("utf-8"),
437437
},
438438
"type": "function",
439439
}
@@ -499,8 +499,8 @@ def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]:
499499
elif isinstance(x, ContentToolResult):
500500
tool_results.append(
501501
ChatCompletionToolMessageParam(
502-
# TODO: a tool could return an image!?!
503-
content=x.get_final_value(),
502+
# Currently, OpenAI only allows for text content in tool results
503+
content=cast(str, x.get_model_value()),
504504
tool_call_id=x.id,
505505
role="tool",
506506
)
@@ -529,7 +529,7 @@ def _as_turn(
529529
contents: list[Content] = []
530530
if message.content is not None:
531531
if has_data_model:
532-
data = json.loads(message.content)
532+
data = orjson.loads(message.content)
533533
contents = [ContentJson(value=data)]
534534
else:
535535
contents = [ContentText(text=message.content)]
@@ -544,8 +544,8 @@ def _as_turn(
544544

545545
args = {}
546546
try:
547-
args = json.loads(func.arguments) if func.arguments else {}
548-
except json.JSONDecodeError:
547+
args = orjson.loads(func.arguments) if func.arguments else {}
548+
except orjson.JSONDecodeError:
549549
raise ValueError(
550550
f"The model's completion included a tool request ({func.name}) "
551551
"with invalid JSON for input arguments: '{func.arguments}'"

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ dependencies = [
88
"requests",
99
"pydantic>=2.0",
1010
"jinja2",
11+
"orjson",
1112
"rich",
1213
]
1314
classifiers = [

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,8 @@ def assert_pdf_local(chat_fun: ChatFun):
256256
wait=wait_exponential(min=1, max=60),
257257
reraise=True,
258258
)
259+
260+
261+
@pytest.fixture
262+
def test_images_dir():
263+
return Path(__file__).parent / "images"

tests/images/dice.png

219 KB
Loading

0 commit comments

Comments
 (0)