Skip to content

Commit 59fbd8c

Browse files
committed
better testing
1 parent ce455d5 commit 59fbd8c

File tree

6 files changed

+212
-288
lines changed

6 files changed

+212
-288
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from langchain_huggingface.chat_models.huggingface import ( # type: ignore[import-not-found]
2+
TGI_MESSAGE,
3+
TGI_RESPONSE,
24
ChatHuggingFace,
35
_convert_dict_to_message,
46
)
57

6-
__all__ = ["ChatHuggingFace", "_convert_dict_to_message"]
8+
__all__ = ["ChatHuggingFace", "_convert_dict_to_message", "TGI_MESSAGE", "TGI_RESPONSE"]

libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py

+163-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Hugging Face Chat Wrapper."""
22

33
import json
4+
from dataclasses import dataclass
5+
from operator import itemgetter
46
from typing import (
57
Any,
68
AsyncIterator,
@@ -52,7 +54,9 @@
5254
from langchain_core.messages.tool import (
5355
tool_call_chunk as create_tool_call_chunk,
5456
)
57+
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
5558
from langchain_core.output_parsers.openai_tools import (
59+
JsonOutputKeyToolsParser,
5660
make_invalid_tool_call,
5761
parse_tool_call,
5862
)
@@ -62,16 +66,41 @@
6266
ChatResult,
6367
LLMResult,
6468
)
65-
from langchain_core.runnables import Runnable
69+
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
6670
from langchain_core.tools import BaseTool
67-
from langchain_core.utils.function_calling import convert_to_openai_tool
68-
from pydantic import Field, model_validator
71+
from langchain_core.utils.function_calling import (
72+
convert_to_json_schema,
73+
convert_to_openai_tool,
74+
)
75+
from langchain_core.utils.pydantic import is_basemodel_subclass
76+
from pydantic import (
77+
BaseModel,
78+
Field,
79+
model_validator,
80+
)
6981
from typing_extensions import Self
7082

7183
from ..llms.huggingface_endpoint import HuggingFaceEndpoint
7284
from ..llms.huggingface_pipeline import HuggingFacePipeline
7385

7486

87+
@dataclass
class TGI_RESPONSE:
    """A single response payload returned by the TextGenInference API."""

    # Candidate completions returned by the server (shape mirrors the
    # OpenAI-style `choices` array).
    choices: List[Any]
    # Token-accounting metadata reported by the server for this call.
    usage: Dict
93+
94+
95+
@dataclass
class TGI_MESSAGE:
    """A single chat message in the format the TextGenInference API expects."""

    # Speaker of the message, e.g. "user", "assistant", or "system".
    role: str
    # Plain-text body of the message.
    content: str
    # Tool invocations attached to the message (OpenAI-style dicts);
    # empty list when the message carries no tool calls.
    tool_calls: List[Dict]
102+
103+
75104
def _lc_tool_call_to_hf_tool_call(tool_call: ToolCall) -> dict:
76105
return {
77106
"type": "function",
@@ -559,6 +588,7 @@ def _generate(
559588
return generate_from_stream(stream_iter)
560589
message_dicts, params = self._create_message_dicts(messages, stop)
561590
params = {
591+
"stop": stop,
562592
**params,
563593
**({"stream": stream} if stream is not None else {}),
564594
**kwargs,
@@ -849,6 +879,136 @@ def bind_tools(
849879
kwargs["tool_choice"] = tool_choice
850880
return super().bind(tools=formatted_tools, **kwargs)
851881

882+
def with_structured_output(
883+
self,
884+
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
885+
*,
886+
method: Literal[
887+
"function_calling", "json_mode", "json_schema"
888+
] = "function_calling",
889+
include_raw: bool = False,
890+
**kwargs: Any,
891+
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
892+
"""Model wrapper that returns outputs formatted to match the given schema.
893+
894+
Args:
895+
schema:
896+
The output schema. Can be passed in as:
897+
- an OpenAI function/tool schema,
898+
- a JSON Schema,
899+
- a TypedDict class (support added in 0.1.7),
900+
901+
Pydantic class is currently supported.
902+
903+
method: The method for steering model generation, one of:
904+
905+
- "function_calling":
906+
Uses Fireworks's `tool-calling features <https://docs.fireworks.ai/guides/function-calling>`_.
907+
- "json_schema":
908+
Uses Fireworks's `structured output feature <https://docs.fireworks.ai/structured-responses/structured-response-formatting>`_.
909+
- "json_mode":
910+
Uses Fireworks's `JSON mode feature <https://docs.fireworks.ai/structured-responses/structured-response-formatting>`_.
911+
912+
include_raw:
913+
If False then only the parsed structured output is returned. If
914+
an error occurs during model output parsing it will be raised. If True
915+
then both the raw model response (a BaseMessage) and the parsed model
916+
response will be returned. If an error occurs during output parsing it
917+
will be caught and returned as well. The final output is always a dict
918+
with keys "raw", "parsed", and "parsing_error".
919+
920+
Returns:
921+
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
922+
923+
If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs
924+
an instance of ``schema`` (i.e., a Pydantic object).
925+
926+
Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
927+
928+
If ``include_raw`` is True, then Runnable outputs a dict with keys:
929+
- ``"raw"``: BaseMessage
930+
- ``"parsed"``: None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
931+
- ``"parsing_error"``: Optional[BaseException]
932+
933+
""" # noqa: E501
934+
_ = kwargs.pop("strict", None)
935+
if kwargs:
936+
raise ValueError(f"Received unsupported arguments {kwargs}")
937+
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
938+
if method == "function_calling":
939+
if schema is None:
940+
raise ValueError(
941+
"schema must be specified when method is 'function_calling'. "
942+
"Received None."
943+
)
944+
formatted_tool = convert_to_openai_tool(schema)
945+
tool_name = formatted_tool["function"]["name"]
946+
llm = self.bind_tools(
947+
[schema],
948+
tool_choice=tool_name,
949+
ls_structured_output_format={
950+
"kwargs": {"method": "function_calling"},
951+
"schema": formatted_tool,
952+
},
953+
)
954+
if is_pydantic_schema:
955+
raise NotImplementedError(
956+
"Pydantic schema is not supported for function calling"
957+
)
958+
else:
959+
output_parser = JsonOutputKeyToolsParser(
960+
key_name=tool_name, first_tool_only=True
961+
)
962+
elif method == "json_schema":
963+
if schema is None:
964+
raise ValueError(
965+
"schema must be specified when method is 'json_schema'. "
966+
"Received None."
967+
)
968+
formatted_schema = convert_to_json_schema(schema)
969+
llm = self.bind(
970+
response_format={"type": "json_object", "schema": formatted_schema},
971+
ls_structured_output_format={
972+
"kwargs": {"method": "json_schema"},
973+
"schema": schema,
974+
},
975+
)
976+
output_parser = (
977+
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
978+
if is_pydantic_schema
979+
else JsonOutputParser()
980+
)
981+
elif method == "json_mode":
982+
llm = self.bind(
983+
response_format={"type": "json_object"},
984+
ls_structured_output_format={
985+
"kwargs": {"method": "json_mode"},
986+
"schema": schema,
987+
},
988+
)
989+
output_parser = (
990+
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
991+
if is_pydantic_schema
992+
else JsonOutputParser()
993+
)
994+
else:
995+
raise ValueError(
996+
f"Unrecognized method argument. Expected one of 'function_calling' or "
997+
f"'json_mode'. Received: '{method}'"
998+
)
999+
1000+
if include_raw:
1001+
parser_assign = RunnablePassthrough.assign(
1002+
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
1003+
)
1004+
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
1005+
parser_with_fallback = parser_assign.with_fallbacks(
1006+
[parser_none], exception_key="parsing_error"
1007+
)
1008+
return RunnableMap(raw=llm) | parser_with_fallback
1009+
else:
1010+
return llm | output_parser
1011+
8521012
def _create_message_dicts(
8531013
self, messages: List[BaseMessage], stop: Optional[List[str]]
8541014
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:

libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py

+2-20
Original file line numberDiff line numberDiff line change
@@ -200,36 +200,18 @@ def build_extra(cls, values: Dict[str, Any]) -> Any:
200200
@model_validator(mode="after")
201201
def validate_environment(self) -> Self:
202202
"""Validate that package is installed and that the API token is valid."""
203-
try:
204-
from huggingface_hub import login # type: ignore[import]
205-
206-
except ImportError:
207-
raise ImportError(
208-
"Could not import huggingface_hub python package. "
209-
"Please install it with `pip install huggingface_hub`."
210-
)
211-
212203
huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
213204
"HF_TOKEN"
214205
)
215206

216-
if huggingfacehub_api_token is not None:
217-
try:
218-
login(token=huggingfacehub_api_token)
219-
except Exception as e:
220-
raise ValueError(
221-
"Could not authenticate with huggingface_hub. "
222-
"Please check your API token."
223-
) from e
224-
225207
from huggingface_hub import AsyncInferenceClient, InferenceClient
226208

227209
# Instantiate clients with supported kwargs
228210
sync_supported_kwargs = set(inspect.signature(InferenceClient).parameters)
229211
self.client = InferenceClient(
230212
model=self.model,
231213
timeout=self.timeout,
232-
token=huggingfacehub_api_token,
214+
api_key=huggingfacehub_api_token,
233215
provider=self.provider,
234216
**{
235217
key: value
@@ -242,7 +224,7 @@ def validate_environment(self) -> Self:
242224
self.async_client = AsyncInferenceClient(
243225
model=self.model,
244226
timeout=self.timeout,
245-
token=huggingfacehub_api_token,
227+
api_key=huggingfacehub_api_token,
246228
provider=self.provider,
247229
**{
248230
key: value

0 commit comments

Comments
 (0)