
Commit 7fc9e99

liuhetian, eyurtsev, and baskaryan authored
openai[patch]: get output_type when using with_structured_output (#26307)
This allows Pydantic to correctly resolve the annotations needed when using OpenAI's new `json_schema` parameter.

Resolves issue: #26250

Co-authored-by: Eugene Yurtsev <[email protected]>
Co-authored-by: Bagatur <[email protected]>
1 parent 0f2b32f commit 7fc9e99
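
A quick, non-normative illustration of the user-facing effect (it mirrors the unit test added below, and is not part of the commit itself): after this change, the runnable returned by with_structured_output(..., method="json_schema") declares the bound Pydantic class as its output type, so get_output_schema() resolves to that class's schema. The Movie class below is purely illustrative; constructing ChatOpenAI assumes an OPENAI_API_KEY is present in the environment (a dummy value suffices here, since inspecting the schema makes no API call).

# Minimal sketch of the behavior this patch enables.
from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class Movie(BaseModel):  # hypothetical schema, for illustration only
    title: str
    year: int


llm = ChatOpenAI()
structured_llm = llm.with_structured_output(Movie, method="json_schema", strict=True)

# After this patch the declared output type is the Pydantic class itself, so the
# advertised schema matches Movie's JSON schema (this is what the new test asserts).
print(structured_llm.get_output_schema().schema())

Actually invoking the chain against the API additionally requires a model that supports OpenAI structured outputs (e.g. gpt-4o-2024-08-06 or later).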

File tree

  • libs/partners/openai

2 files changed: +36 / -7 lines

libs/partners/openai/langchain_openai/chat_models/base.py

Lines changed: 7 additions & 7 deletions

@@ -65,7 +65,6 @@
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.messages.tool import tool_call_chunk
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
-from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.openai_tools import (
     JsonOutputKeyToolsParser,
     PydanticToolsParser,
@@ -1421,7 +1420,7 @@ class AnswerWithJustification(BaseModel):
                 strict=strict,
             )
             if is_pydantic_schema:
-                output_parser: OutputParserLike = PydanticToolsParser(
+                output_parser: Runnable = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
                     first_tool_only=True,  # type: ignore[list-item]
                 )
@@ -1445,11 +1444,12 @@ class AnswerWithJustification(BaseModel):
             strict = strict if strict is not None else True
             response_format = _convert_to_openai_response_format(schema, strict=strict)
             llm = self.bind(response_format=response_format)
-            output_parser = (
-                cast(Runnable, _oai_structured_outputs_parser)
-                if is_pydantic_schema
-                else JsonOutputParser()
-            )
+            if is_pydantic_schema:
+                output_parser = _oai_structured_outputs_parser.with_types(
+                    output_type=cast(type, schema)
+                )
+            else:
+                output_parser = JsonOutputParser()
         else:
             raise ValueError(
                 f"Unrecognized method argument. Expected one of 'function_calling' or "

libs/partners/openai/tests/unit_tests/chat_models/test_base.py

Lines changed: 29 additions & 0 deletions

@@ -18,6 +18,7 @@
 )
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel as BaseModelV2

 from langchain_openai import ChatOpenAI
 from langchain_openai.chat_models.base import (
@@ -694,3 +695,31 @@ def test_get_num_tokens_from_messages() -> None:
     expected = 176
     actual = llm.get_num_tokens_from_messages(messages)
     assert expected == actual
+
+
+class Foo(BaseModel):
+    bar: int
+
+
+class FooV2(BaseModelV2):
+    bar: int
+
+
+@pytest.mark.parametrize("schema", [Foo, FooV2])
+def test_schema_from_with_structured_output(schema: Type) -> None:
+    """Test schema from with_structured_output."""
+
+    llm = ChatOpenAI()
+
+    structured_llm = llm.with_structured_output(
+        schema, method="json_schema", strict=True
+    )
+
+    expected = {
+        "properties": {"bar": {"title": "Bar", "type": "integer"}},
+        "required": ["bar"],
+        "title": schema.__name__,
+        "type": "object",
+    }
+    actual = structured_llm.get_output_schema().schema()
+    assert actual == expected
