1
1
import asyncio
2
+ import json
2
3
from typing import TYPE_CHECKING , Iterable , Literal , Optional , TypedDict , cast , overload
3
4
4
5
from pydantic import BaseModel
5
6
6
7
from ._chat import Chat
7
- from ._content import Content
8
+ from ._content import Content , ContentJson , ContentText
8
9
from ._logging import log_model_default
9
10
from ._provider import Provider
10
- from ._tools import Tool
11
+ from ._tools import Tool , basemodel_to_param_schema
11
12
from ._turn import Turn , normalize_turns
12
13
from ._utils import drop_none , wrap_async_iterable
13
14
@@ -259,20 +260,6 @@ def _chat_perform_args(
259
260
data_model : Optional [type [BaseModel ]] = None ,
260
261
kwargs : Optional ["SubmitInputArgs" ] = None ,
261
262
):
262
- # Cortex doesn't seem to support tools
263
- if tools :
264
- raise ValueError ("Snowflake does not currently support tools." )
265
-
266
- # TODO: implement data_model when this PR makes it into snowflake-ml-python
267
- # https://github.com/snowflakedb/snowflake-ml-python/pull/141
268
- # https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-llm-rest-api#structured-output-example
269
- if data_model :
270
- raise NotImplementedError (
271
- "The snowflake-ml-python package currently doesn't support structured output. "
272
- "Upvote this PR to help prioritize it: "
273
- "https://github.com/snowflakedb/snowflake-ml-python/pull/141"
274
- )
275
-
276
263
kwargs_full : "SubmitInputArgs" = {
277
264
"stream" : stream ,
278
265
"prompt" : self ._as_prompt_input (turns ),
@@ -281,6 +268,23 @@ def _chat_perform_args(
281
268
** (kwargs or {}),
282
269
}
283
270
271
+ # TODO: get tools working
272
+ if tools :
273
+ raise ValueError ("Snowflake does not currently support tools." )
274
+
275
+ if data_model is not None :
276
+ params = basemodel_to_param_schema (data_model )
277
+ opts = kwargs_full .get ("options" ) or {}
278
+ opts ["response_format" ] = {
279
+ "type" : "json" ,
280
+ "schema" : {
281
+ "type" : "object" ,
282
+ "properties" : params ["properties" ],
283
+ "required" : params ["required" ],
284
+ },
285
+ }
286
+ kwargs_full ["options" ] = opts
287
+
284
288
return kwargs_full
285
289
286
290
def stream_text (self , chunk ):
@@ -323,10 +327,18 @@ def _as_prompt_input(self, turns: list[Turn]) -> list["ConversationMessage"]:
323
327
res .append (
324
328
{
325
329
"role" : turn .role ,
326
- "content" : turn . text ,
330
+ "content" : str ( turn ) ,
327
331
}
328
332
)
329
333
return res
330
334
331
335
def _as_turn (self , completion , has_data_model ) -> Turn :
332
- return Turn ("assistant" , completion )
336
+ completion = cast (str , completion )
337
+
338
+ if has_data_model :
339
+ data = json .loads (completion )
340
+ contents = [ContentJson (value = data )]
341
+ else :
342
+ contents = [ContentText (text = completion )]
343
+
344
+ return Turn ("assistant" , contents )
0 commit comments