Skip to content

Commit 72be129

Browse files
authored
Add .extract_data() support to ChatSnowflake() (#84)
* Add .extract_data() support to ChatSnowflake() * Update types * Fix import formatting * Simplify
1 parent b351a27 commit 72be129

File tree

3 files changed

+40
-20
lines changed

3 files changed

+40
-20
lines changed

CHANGELOG.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2020
### Improvements
2121

2222
* When a tool call ends in failure, a warning is now raised and the stacktrace is printed. (#79)
23-
* `ChatSnowflake()` now supports `async` calls. (#81)
24-
* `ChatSnowflake()` no longer errors due to more than one session being active. (#83)
23+
* Several improvements to `ChatSnowflake()`:
24+
* `.extract_data()` is now supported.
25+
* `async` methods are now supported. (#81)
26+
* Fixed an issue with more than one session being active at once. (#83)
2527

2628
### Changes
2729

chatlas/_snowflake.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import asyncio
2+
import json
23
from typing import TYPE_CHECKING, Iterable, Literal, Optional, TypedDict, cast, overload
34

45
from pydantic import BaseModel
56

67
from ._chat import Chat
7-
from ._content import Content
8+
from ._content import Content, ContentJson, ContentText
89
from ._logging import log_model_default
910
from ._provider import Provider
10-
from ._tools import Tool
11+
from ._tools import Tool, basemodel_to_param_schema
1112
from ._turn import Turn, normalize_turns
1213
from ._utils import drop_none, wrap_async_iterable
1314

@@ -259,20 +260,6 @@ def _chat_perform_args(
259260
data_model: Optional[type[BaseModel]] = None,
260261
kwargs: Optional["SubmitInputArgs"] = None,
261262
):
262-
# Cortex doesn't seem to support tools
263-
if tools:
264-
raise ValueError("Snowflake does not currently support tools.")
265-
266-
# TODO: implement data_model when this PR makes it into snowflake-ml-python
267-
# https://github.com/snowflakedb/snowflake-ml-python/pull/141
268-
# https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#structured-output-example
269-
if data_model:
270-
raise NotImplementedError(
271-
"The snowflake-ml-python package currently doesn't support structured output. "
272-
"Upvote this PR to help prioritize it: "
273-
"https://github.com/snowflakedb/snowflake-ml-python/pull/141"
274-
)
275-
276263
kwargs_full: "SubmitInputArgs" = {
277264
"stream": stream,
278265
"prompt": self._as_prompt_input(turns),
@@ -281,6 +268,23 @@ def _chat_perform_args(
281268
**(kwargs or {}),
282269
}
283270

271+
# TODO: get tools working
272+
if tools:
273+
raise ValueError("Snowflake does not currently support tools.")
274+
275+
if data_model is not None:
276+
params = basemodel_to_param_schema(data_model)
277+
opts = kwargs_full.get("options") or {}
278+
opts["response_format"] = {
279+
"type": "json",
280+
"schema": {
281+
"type": "object",
282+
"properties": params["properties"],
283+
"required": params["required"],
284+
},
285+
}
286+
kwargs_full["options"] = opts
287+
284288
return kwargs_full
285289

286290
def stream_text(self, chunk):
@@ -323,10 +327,18 @@ def _as_prompt_input(self, turns: list[Turn]) -> list["ConversationMessage"]:
323327
res.append(
324328
{
325329
"role": turn.role,
326-
"content": turn.text,
330+
"content": str(turn),
327331
}
328332
)
329333
return res
330334

331335
def _as_turn(self, completion, has_data_model) -> Turn:
332-
return Turn("assistant", completion)
336+
completion = cast(str, completion)
337+
338+
if has_data_model:
339+
data = json.loads(completion)
340+
contents = [ContentJson(value=data)]
341+
else:
342+
contents = [ContentText(text=completion)]
343+
344+
return Turn("assistant", contents)

chatlas/types/openai/_submit.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ class SubmitInputArgs(TypedDict, total=False):
3838
model: Union[
3939
str,
4040
Literal[
41+
"gpt-4.1",
42+
"gpt-4.1-mini",
43+
"gpt-4.1-nano",
44+
"gpt-4.1-2025-04-14",
45+
"gpt-4.1-mini-2025-04-14",
46+
"gpt-4.1-nano-2025-04-14",
4147
"o3-mini",
4248
"o3-mini-2025-01-31",
4349
"o1",

0 commit comments

Comments
 (0)