|
1 | 1 | """Hugging Face Chat Wrapper."""
|
2 | 2 |
|
3 | 3 | import json
|
| 4 | +from dataclasses import dataclass |
| 5 | +from operator import itemgetter |
4 | 6 | from typing import (
|
5 | 7 | Any,
|
6 | 8 | AsyncIterator,
|
|
52 | 54 | from langchain_core.messages.tool import (
|
53 | 55 | tool_call_chunk as create_tool_call_chunk,
|
54 | 56 | )
|
| 57 | +from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser |
55 | 58 | from langchain_core.output_parsers.openai_tools import (
|
| 59 | + JsonOutputKeyToolsParser, |
56 | 60 | make_invalid_tool_call,
|
57 | 61 | parse_tool_call,
|
58 | 62 | )
|
|
62 | 66 | ChatResult,
|
63 | 67 | LLMResult,
|
64 | 68 | )
|
65 |
| -from langchain_core.runnables import Runnable |
| 69 | +from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough |
66 | 70 | from langchain_core.tools import BaseTool
|
67 |
| -from langchain_core.utils.function_calling import convert_to_openai_tool |
68 |
| -from pydantic import Field, model_validator |
| 71 | +from langchain_core.utils.function_calling import ( |
| 72 | + convert_to_json_schema, |
| 73 | + convert_to_openai_tool, |
| 74 | +) |
| 75 | +from langchain_core.utils.pydantic import is_basemodel_subclass |
| 76 | +from pydantic import ( |
| 77 | + BaseModel, |
| 78 | + Field, |
| 79 | + model_validator, |
| 80 | +) |
69 | 81 | from typing_extensions import Self
|
70 | 82 |
|
71 | 83 | from ..llms.huggingface_endpoint import HuggingFaceEndpoint
|
72 | 84 | from ..llms.huggingface_pipeline import HuggingFacePipeline
|
73 | 85 |
|
74 | 86 |
|
@dataclass
class TGI_RESPONSE:
    """Response from the TextGenInference API.

    Lightweight container mirroring the OpenAI-style chat-completion
    response shape returned by a text-generation-inference server.
    """

    # One entry per generated completion choice.
    # NOTE(review): element shape depends on the TGI server response —
    # presumably each carries a message payload; confirm against callers.
    choices: List[Any]
    # Token-accounting metadata from the server (e.g. prompt/completion
    # token counts). Exact keys are server-version dependent — verify.
    usage: Dict
| 94 | + |
@dataclass
class TGI_MESSAGE:
    """Message to send to the TextGenInference API.

    OpenAI-style chat message: a role, text content, and any tool calls
    attached to the message.
    """

    # Chat role — presumably one of "system"/"user"/"assistant"/"tool";
    # confirm against the TGI chat-completion endpoint.
    role: str
    # Plain-text content of the message.
    content: str
    # Tool-call payloads carried by the message; empty list when the
    # message has no tool calls.
    tool_calls: List[Dict]
| 102 | + |
| 103 | + |
75 | 104 | def _lc_tool_call_to_hf_tool_call(tool_call: ToolCall) -> dict:
|
76 | 105 | return {
|
77 | 106 | "type": "function",
|
@@ -559,6 +588,7 @@ def _generate(
|
559 | 588 | return generate_from_stream(stream_iter)
|
560 | 589 | message_dicts, params = self._create_message_dicts(messages, stop)
|
561 | 590 | params = {
|
| 591 | + "stop": stop, |
562 | 592 | **params,
|
563 | 593 | **({"stream": stream} if stream is not None else {}),
|
564 | 594 | **kwargs,
|
@@ -849,6 +879,136 @@ def bind_tools(
|
849 | 879 | kwargs["tool_choice"] = tool_choice
|
850 | 880 | return super().bind(tools=formatted_tools, **kwargs)
|
851 | 881 |
|
| 882 | + def with_structured_output( |
| 883 | + self, |
| 884 | + schema: Optional[Union[Dict, Type[BaseModel]]] = None, |
| 885 | + *, |
| 886 | + method: Literal[ |
| 887 | + "function_calling", "json_mode", "json_schema" |
| 888 | + ] = "function_calling", |
| 889 | + include_raw: bool = False, |
| 890 | + **kwargs: Any, |
| 891 | + ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: |
| 892 | + """Model wrapper that returns outputs formatted to match the given schema. |
| 893 | +
|
| 894 | + Args: |
| 895 | + schema: |
| 896 | + The output schema. Can be passed in as: |
| 897 | + - an OpenAI function/tool schema, |
| 898 | + - a JSON Schema, |
| 899 | + - a TypedDict class (support added in 0.1.7), |
| 900 | +
|
| 901 | + Pydantic class is currently supported. |
| 902 | +
|
| 903 | + method: The method for steering model generation, one of: |
| 904 | +
|
| 905 | + - "function_calling": |
| 906 | + Uses Fireworks's `tool-calling features <https://docs.fireworks.ai/guides/function-calling>`_. |
| 907 | + - "json_schema": |
| 908 | + Uses Fireworks's `structured output feature <https://docs.fireworks.ai/structured-responses/structured-response-formatting>`_. |
| 909 | + - "json_mode": |
| 910 | + Uses Fireworks's `JSON mode feature <https://docs.fireworks.ai/structured-responses/structured-response-formatting>`_. |
| 911 | +
|
| 912 | + include_raw: |
| 913 | + If False then only the parsed structured output is returned. If |
| 914 | + an error occurs during model output parsing it will be raised. If True |
| 915 | + then both the raw model response (a BaseMessage) and the parsed model |
| 916 | + response will be returned. If an error occurs during output parsing it |
| 917 | + will be caught and returned as well. The final output is always a dict |
| 918 | + with keys "raw", "parsed", and "parsing_error". |
| 919 | +
|
| 920 | + Returns: |
| 921 | + A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. |
| 922 | +
|
| 923 | + If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs |
| 924 | + an instance of ``schema`` (i.e., a Pydantic object). |
| 925 | +
|
| 926 | + Otherwise, if ``include_raw`` is False then Runnable outputs a dict. |
| 927 | +
|
| 928 | + If ``include_raw`` is True, then Runnable outputs a dict with keys: |
| 929 | + - ``"raw"``: BaseMessage |
| 930 | + - ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above. |
| 931 | + - ``"parsing_error"``: Optional[BaseException] |
| 932 | +
|
| 933 | + """ # noqa: E501 |
| 934 | + _ = kwargs.pop("strict", None) |
| 935 | + if kwargs: |
| 936 | + raise ValueError(f"Received unsupported arguments {kwargs}") |
| 937 | + is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema) |
| 938 | + if method == "function_calling": |
| 939 | + if schema is None: |
| 940 | + raise ValueError( |
| 941 | + "schema must be specified when method is 'function_calling'. " |
| 942 | + "Received None." |
| 943 | + ) |
| 944 | + formatted_tool = convert_to_openai_tool(schema) |
| 945 | + tool_name = formatted_tool["function"]["name"] |
| 946 | + llm = self.bind_tools( |
| 947 | + [schema], |
| 948 | + tool_choice=tool_name, |
| 949 | + ls_structured_output_format={ |
| 950 | + "kwargs": {"method": "function_calling"}, |
| 951 | + "schema": formatted_tool, |
| 952 | + }, |
| 953 | + ) |
| 954 | + if is_pydantic_schema: |
| 955 | + raise NotImplementedError( |
| 956 | + "Pydantic schema is not supported for function calling" |
| 957 | + ) |
| 958 | + else: |
| 959 | + output_parser = JsonOutputKeyToolsParser( |
| 960 | + key_name=tool_name, first_tool_only=True |
| 961 | + ) |
| 962 | + elif method == "json_schema": |
| 963 | + if schema is None: |
| 964 | + raise ValueError( |
| 965 | + "schema must be specified when method is 'json_schema'. " |
| 966 | + "Received None." |
| 967 | + ) |
| 968 | + formatted_schema = convert_to_json_schema(schema) |
| 969 | + llm = self.bind( |
| 970 | + response_format={"type": "json_object", "schema": formatted_schema}, |
| 971 | + ls_structured_output_format={ |
| 972 | + "kwargs": {"method": "json_schema"}, |
| 973 | + "schema": schema, |
| 974 | + }, |
| 975 | + ) |
| 976 | + output_parser = ( |
| 977 | + PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] |
| 978 | + if is_pydantic_schema |
| 979 | + else JsonOutputParser() |
| 980 | + ) |
| 981 | + elif method == "json_mode": |
| 982 | + llm = self.bind( |
| 983 | + response_format={"type": "json_object"}, |
| 984 | + ls_structured_output_format={ |
| 985 | + "kwargs": {"method": "json_mode"}, |
| 986 | + "schema": schema, |
| 987 | + }, |
| 988 | + ) |
| 989 | + output_parser = ( |
| 990 | + PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type] |
| 991 | + if is_pydantic_schema |
| 992 | + else JsonOutputParser() |
| 993 | + ) |
| 994 | + else: |
| 995 | + raise ValueError( |
| 996 | + f"Unrecognized method argument. Expected one of 'function_calling' or " |
| 997 | + f"'json_mode'. Received: '{method}'" |
| 998 | + ) |
| 999 | + |
| 1000 | + if include_raw: |
| 1001 | + parser_assign = RunnablePassthrough.assign( |
| 1002 | + parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None |
| 1003 | + ) |
| 1004 | + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) |
| 1005 | + parser_with_fallback = parser_assign.with_fallbacks( |
| 1006 | + [parser_none], exception_key="parsing_error" |
| 1007 | + ) |
| 1008 | + return RunnableMap(raw=llm) | parser_with_fallback |
| 1009 | + else: |
| 1010 | + return llm | output_parser |
| 1011 | + |
852 | 1012 | def _create_message_dicts(
|
853 | 1013 | self, messages: List[BaseMessage], stop: Optional[List[str]]
|
854 | 1014 | ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
|
0 commit comments