From 0f8875971a0e4ac6f27640ca6f73d48cf99f1667 Mon Sep 17 00:00:00 2001 From: Carson Date: Mon, 14 Apr 2025 12:53:06 -0500 Subject: [PATCH 1/4] Add .extract_data() support to ChatSnowflake() --- CHANGELOG.md | 6 +++-- chatlas/_snowflake.py | 52 +++++++++++++++++++++++++++---------------- 2 files changed, 37 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d11bfed3..4b436deb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,8 +20,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Improvements * When a tool call ends in failure, a warning is now raised and the stacktrace is printed. (#79) -* `ChatSnowflake()` now supports `async` calls. (#81) -* `ChatSnowflake()` no longer errors due to more than one session being active. (#83) +* Several improvements to `ChatSnowflake()`: + * `.extract_data()` is now supported. + * `async` methods are now supported. (#81) + * Fixed an issue with more than one session being active at once. (#83) ### Changes diff --git a/chatlas/_snowflake.py b/chatlas/_snowflake.py index 7aed37db..0433b875 100644 --- a/chatlas/_snowflake.py +++ b/chatlas/_snowflake.py @@ -1,17 +1,20 @@ import asyncio -from typing import TYPE_CHECKING, Iterable, Literal, Optional, TypedDict, cast, overload +import json +from typing import (TYPE_CHECKING, Iterable, Literal, Optional, TypedDict, + cast, overload) from pydantic import BaseModel from ._chat import Chat -from ._content import Content +from ._content import Content, ContentJson, ContentText from ._logging import log_model_default from ._provider import Provider -from ._tools import Tool +from ._tools import Tool, basemodel_to_param_schema from ._turn import Turn, normalize_turns from ._utils import drop_none, wrap_async_iterable if TYPE_CHECKING: + from snowflake.cortex._complete import CompleteOptions from snowflake.snowpark import Column # Types inferred from the return type of the `snowflake.cortex.complete` function @@ -259,20 +262,6 @@ def _chat_perform_args( data_model: Optional[type[BaseModel]] = None, kwargs: Optional["SubmitInputArgs"] = None, ): - # Cortex doesn't seem to support tools - if tools: - raise ValueError("Snowflake does not currently support tools.") - - # TODO: implement data_model when this PR makes it into snowflake-ml-python - # https://github.com/snowflakedb/snowflake-ml-python/pull/141 - # https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#structured-output-example - if data_model: - raise NotImplementedError( - "The snowflake-ml-python package currently doesn't support structured output. " - "Upvote this PR to help prioritize it: " - "https://github.com/snowflakedb/snowflake-ml-python/pull/141" - ) - kwargs_full: "SubmitInputArgs" = { "stream": stream, "prompt": self._as_prompt_input(turns), @@ -281,6 +270,23 @@ def _chat_perform_args( **(kwargs or {}), } + # TODO: get tools working + if tools: + raise ValueError("Snowflake does not currently support tools.") + + if data_model is not None: + params = basemodel_to_param_schema(data_model) + opts: CompleteOptions = kwargs_full.get("options") or {} + opts["response_format"] = { + "type": "json", + "schema": { + "type": "object", + "properties": params["properties"], + "required": params["required"], + }, + } + kwargs_full["options"] = opts + return kwargs_full def stream_text(self, chunk): @@ -323,10 +329,18 @@ def _as_prompt_input(self, turns: list[Turn]) -> list["ConversationMessage"]: res.append( { "role": turn.role, - "content": turn.text, + "content": str(turn), } ) return res def _as_turn(self, completion, has_data_model) -> Turn: - return Turn("assistant", completion) + completion = cast(str, completion) + + if has_data_model: + data = json.loads(completion) + contents = [ContentJson(value=data)] + else: + contents = [ContentText(text=completion)] + + return Turn("assistant", contents) From cb39206158229cb81ef4a77111cd666709e36ef2 Mon Sep 17 00:00:00 2001 From: Carson Date: Mon, 14 Apr 2025 15:11:34 -0500 Subject: [PATCH 2/4] Update types --- chatlas/types/openai/_submit.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/chatlas/types/openai/_submit.py b/chatlas/types/openai/_submit.py index f0e90285..75efd761 100644 --- a/chatlas/types/openai/_submit.py +++ b/chatlas/types/openai/_submit.py @@ -38,6 +38,12 @@ class SubmitInputArgs(TypedDict, total=False): model: Union[ str, Literal[ + "gpt-4.1", + "gpt-4.1-mini", + "gpt-4.1-nano", + "gpt-4.1-2025-04-14", + "gpt-4.1-mini-2025-04-14", + "gpt-4.1-nano-2025-04-14", "o3-mini", "o3-mini-2025-01-31", "o1", From c7c8156f952bedde87ff67fd9d2536f40791f143 Mon Sep 17 00:00:00 2001 From: Carson Date: Mon, 14 Apr 2025 15:15:00 -0500 Subject: [PATCH 3/4] Fix import formatting --- chatlas/_snowflake.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/chatlas/_snowflake.py b/chatlas/_snowflake.py index 0433b875..6023bc86 100644 --- a/chatlas/_snowflake.py +++ b/chatlas/_snowflake.py @@ -1,7 +1,6 @@ import asyncio import json -from typing import (TYPE_CHECKING, Iterable, Literal, Optional, TypedDict, - cast, overload) +from typing import TYPE_CHECKING, Iterable, Literal, Optional, TypedDict, cast, overload from pydantic import BaseModel From 17c51f6b32f5fe0bdd016a27036d72badfc424bc Mon Sep 17 00:00:00 2001 From: Carson Date: Mon, 14 Apr 2025 15:18:38 -0500 Subject: [PATCH 4/4] Simplify --- chatlas/_snowflake.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/chatlas/_snowflake.py b/chatlas/_snowflake.py index 6023bc86..6999edbd 100644 --- a/chatlas/_snowflake.py +++ b/chatlas/_snowflake.py @@ -13,7 +13,6 @@ from ._utils import drop_none, wrap_async_iterable if TYPE_CHECKING: - from snowflake.cortex._complete import CompleteOptions from snowflake.snowpark import Column # Types inferred from the return type of the `snowflake.cortex.complete` function @@ -275,7 +274,7 @@ def _chat_perform_args( if data_model is not None: params = basemodel_to_param_schema(data_model) - opts: CompleteOptions = kwargs_full.get("options") or {} + opts = kwargs_full.get("options") or {} opts["response_format"] = { "type": "json", "schema": {