Skip to content

Add .extract_data() support to ChatSnowflake() #84

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Improvements

* When a tool call ends in failure, a warning is now raised and the stacktrace is printed. (#79)
* `ChatSnowflake()` now supports `async` calls. (#81)
* `ChatSnowflake()` no longer errors due to more than one session being active. (#83)
* Several improvements to `ChatSnowflake()`:
  * `.extract_data()` is now supported.
  * `async` methods are now supported. (#81)
  * Fixed an issue with more than one session being active at once. (#83)

### Changes

Expand Down
48 changes: 30 additions & 18 deletions chatlas/_snowflake.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import asyncio
import json
from typing import TYPE_CHECKING, Iterable, Literal, Optional, TypedDict, cast, overload

from pydantic import BaseModel

from ._chat import Chat
from ._content import Content
from ._content import Content, ContentJson, ContentText
from ._logging import log_model_default
from ._provider import Provider
from ._tools import Tool
from ._tools import Tool, basemodel_to_param_schema
from ._turn import Turn, normalize_turns
from ._utils import drop_none, wrap_async_iterable

Expand Down Expand Up @@ -259,20 +260,6 @@ def _chat_perform_args(
data_model: Optional[type[BaseModel]] = None,
kwargs: Optional["SubmitInputArgs"] = None,
):
# Cortex doesn't seem to support tools
if tools:
raise ValueError("Snowflake does not currently support tools.")

# TODO: implement data_model when this PR makes it into snowflake-ml-python
# https://github.com/snowflakedb/snowflake-ml-python/pull/141
# https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#structured-output-example
if data_model:
raise NotImplementedError(
"The snowflake-ml-python package currently doesn't support structured output. "
"Upvote this PR to help prioritize it: "
"https://github.com/snowflakedb/snowflake-ml-python/pull/141"
)

kwargs_full: "SubmitInputArgs" = {
"stream": stream,
"prompt": self._as_prompt_input(turns),
Expand All @@ -281,6 +268,23 @@ def _chat_perform_args(
**(kwargs or {}),
}

# TODO: get tools working
if tools:
raise ValueError("Snowflake does not currently support tools.")

if data_model is not None:
params = basemodel_to_param_schema(data_model)
opts = kwargs_full.get("options") or {}
opts["response_format"] = {
"type": "json",
"schema": {
"type": "object",
"properties": params["properties"],
"required": params["required"],
},
}
kwargs_full["options"] = opts

return kwargs_full

def stream_text(self, chunk):
Expand Down Expand Up @@ -323,10 +327,18 @@ def _as_prompt_input(self, turns: list[Turn]) -> list["ConversationMessage"]:
res.append(
{
"role": turn.role,
"content": turn.text,
"content": str(turn),
}
)
return res

def _as_turn(self, completion, has_data_model) -> Turn:
return Turn("assistant", completion)
completion = cast(str, completion)

if has_data_model:
data = json.loads(completion)
contents = [ContentJson(value=data)]
else:
contents = [ContentText(text=completion)]

return Turn("assistant", contents)
6 changes: 6 additions & 0 deletions chatlas/types/openai/_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ class SubmitInputArgs(TypedDict, total=False):
model: Union[
str,
Literal[
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
"gpt-4.1-2025-04-14",
"gpt-4.1-mini-2025-04-14",
"gpt-4.1-nano-2025-04-14",
"o3-mini",
"o3-mini-2025-01-31",
"o1",
Expand Down