Commit 12cc078

AbsoChatModel working

1 parent e61a981 commit 12cc078

File tree

4 files changed: +397 -12 lines changed

langchain/main.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
from model import ChatAbso
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
)


abso = ChatAbso(fast_model="gpt-4o", slow_model="o3-mini")

# A simple prompt; the router should pick the fast model.
res = abso.invoke([HumanMessage(content="hello")])
print(res.content)

# A harder prompt; the router may escalate to the slow model.
res = abso.invoke([HumanMessage(content="what's the meaning of life")])
print(res.content)

langchain/model.py

Lines changed: 376 additions & 0 deletions
@@ -0,0 +1,376 @@
from typing import Any, Dict, Iterator, List, Optional, Union, Mapping, cast
import requests
import os
import json


from langchain_core.output_parsers.openai_tools import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
    make_invalid_tool_call,
    parse_tool_call,
)

from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    InvalidToolCall,
    ChatMessage,
    HumanMessage,
    ToolCall,
    ToolMessage,
    SystemMessage,
    FunctionMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.messages.ai import (
    InputTokenDetails,
    OutputTokenDetails,
    UsageMetadata,
)


def _format_message_content(content: Any) -> Any:
    """Format message content."""
    if content and isinstance(content, list):
        # Remove unexpected block types
        formatted_content = []
        for block in content:
            if (
                isinstance(block, dict)
                and "type" in block
                and block["type"] == "tool_use"
            ):
                continue
            else:
                formatted_content.append(block)
    else:
        formatted_content = content

    return formatted_content

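# Illustrative sketch (editor's addition, not part of this commit): tool_use
# blocks are dropped, everything else passes through unchanged.
#
#     _format_message_content(
#         [{"type": "tool_use", "id": "x"}, {"type": "text", "text": "hi"}]
#     )
#     # -> [{"type": "text", "text": "hi"}]
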
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to a dictionary.

    Args:
        message: The LangChain message.

    Returns:
        The dictionary.
    """
    message_dict: Dict[str, Any] = {"content": _format_message_content(message.content)}
    if (name := message.name or message.additional_kwargs.get("name")) is not None:
        message_dict["name"] = name

    # populate role and additional message data
    if isinstance(message, ChatMessage):
        message_dict["role"] = message.role
    elif isinstance(message, HumanMessage):
        message_dict["role"] = "user"
    elif isinstance(message, AIMessage):
        message_dict["role"] = "assistant"
        if "function_call" in message.additional_kwargs:
            message_dict["function_call"] = message.additional_kwargs["function_call"]
        if message.tool_calls or message.invalid_tool_calls:
            message_dict["tool_calls"] = [
                _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
            ] + [
                _lc_invalid_tool_call_to_openai_tool_call(tc)
                for tc in message.invalid_tool_calls
            ]
        elif "tool_calls" in message.additional_kwargs:
            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
            tool_call_supported_props = {"id", "type", "function"}
            message_dict["tool_calls"] = [
                {k: v for k, v in tool_call.items() if k in tool_call_supported_props}
                for tool_call in message_dict["tool_calls"]
            ]
        else:
            pass
        # If tool calls present, content null value should be None not empty string.
        if "function_call" in message_dict or "tool_calls" in message_dict:
            message_dict["content"] = message_dict["content"] or None

        if "audio" in message.additional_kwargs:
            raw_audio = message.additional_kwargs["audio"]
            audio = (
                {"id": message.additional_kwargs["audio"]["id"]}
                if "id" in raw_audio
                else raw_audio
            )
            message_dict["audio"] = audio
    elif isinstance(message, SystemMessage):
        message_dict["role"] = message.additional_kwargs.get(
            "__openai_role__", "system"
        )
    elif isinstance(message, FunctionMessage):
        message_dict["role"] = "function"
    elif isinstance(message, ToolMessage):
        message_dict["role"] = "tool"
        message_dict["tool_call_id"] = message.tool_call_id

        supported_props = {"content", "role", "tool_call_id"}
        message_dict = {k: v for k, v in message_dict.items() if k in supported_props}
    else:
        raise TypeError(f"Got unknown type {message}")
    return message_dict

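# Illustrative sketch (editor's addition, not part of this commit): expected
# behavior of the converter for two common message types.
#
#     _convert_message_to_dict(HumanMessage(content="hi"))
#     # -> {"content": "hi", "role": "user"}
#
#     _convert_message_to_dict(ToolMessage(content="42", tool_call_id="call_1"))
#     # -> {"content": "42", "role": "tool", "tool_call_id": "call_1"}
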
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a dictionary to a LangChain message.

    Args:
        _dict: The dictionary.

    Returns:
        The LangChain message.
    """
    role = _dict.get("role")
    name = _dict.get("name")
    id_ = _dict.get("id")
    if role == "user":
        return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
    elif role == "assistant":
        # Fix for azure
        # Also OpenAI returns None for tool invocations
        content = _dict.get("content", "") or ""
        additional_kwargs: Dict = {}
        if function_call := _dict.get("function_call"):
            additional_kwargs["function_call"] = dict(function_call)
        tool_calls = []
        invalid_tool_calls = []
        if raw_tool_calls := _dict.get("tool_calls"):
            additional_kwargs["tool_calls"] = raw_tool_calls
            for raw_tool_call in raw_tool_calls:
                try:
                    tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
                except Exception as e:
                    invalid_tool_calls.append(
                        make_invalid_tool_call(raw_tool_call, str(e))
                    )
        if audio := _dict.get("audio"):
            additional_kwargs["audio"] = audio
        return AIMessage(
            content=content,
            additional_kwargs=additional_kwargs,
            name=name,
            id=id_,
            tool_calls=tool_calls,
            invalid_tool_calls=invalid_tool_calls,
        )
    elif role in ("system", "developer"):
        if role == "developer":
            additional_kwargs = {"__openai_role__": role}
        else:
            additional_kwargs = {}
        return SystemMessage(
            content=_dict.get("content", ""),
            name=name,
            id=id_,
            additional_kwargs=additional_kwargs,
        )
    elif role == "function":
        return FunctionMessage(
            content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
        )
    elif role == "tool":
        additional_kwargs = {}
        if "name" in _dict:
            additional_kwargs["name"] = _dict["name"]
        return ToolMessage(
            content=_dict.get("content", ""),
            tool_call_id=cast(str, _dict.get("tool_call_id")),
            additional_kwargs=additional_kwargs,
            name=name,
            id=id_,
        )
    else:
        return ChatMessage(content=_dict.get("content", ""), role=role, id=id_)  # type: ignore[arg-type]

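# Illustrative sketch (editor's addition, not part of this commit): the inverse
# direction, turning an OpenAI-style response message back into LangChain types.
#
#     _convert_dict_to_message({"role": "assistant", "content": "hello", "id": "m1"})
#     # -> AIMessage(content="hello", id="m1")
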
def _lc_invalid_tool_call_to_openai_tool_call(
    invalid_tool_call: InvalidToolCall,
) -> dict:
    return {
        "type": "function",
        "id": invalid_tool_call["id"],
        "function": {
            "name": invalid_tool_call["name"],
            "arguments": invalid_tool_call["args"],
        },
    }


def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
    return {
        "type": "function",
        "id": tool_call["id"],
        "function": {
            "name": tool_call["name"],
            "arguments": json.dumps(tool_call["args"]),
        },
    }

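# Illustrative sketch (editor's addition, not part of this commit): args are
# serialized to a JSON string, matching the OpenAI tool-call wire format.
#
#     _lc_tool_call_to_openai_tool_call(
#         ToolCall(name="add", args={"a": 1, "b": 2}, id="call_1")
#     )
#     # -> {"type": "function", "id": "call_1",
#     #     "function": {"name": "add", "arguments": '{"a": 1, "b": 2}'}}
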
def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata:
    input_tokens = oai_token_usage.get("prompt_tokens", 0)
    output_tokens = oai_token_usage.get("completion_tokens", 0)
    total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens)
    input_token_details: dict = {
        "audio": (oai_token_usage.get("prompt_tokens_details") or {}).get(
            "audio_tokens"
        ),
        "cache_read": (oai_token_usage.get("prompt_tokens_details") or {}).get(
            "cached_tokens"
        ),
    }
    output_token_details: dict = {
        "audio": (oai_token_usage.get("completion_tokens_details") or {}).get(
            "audio_tokens"
        ),
        "reasoning": (oai_token_usage.get("completion_tokens_details") or {}).get(
            "reasoning_tokens"
        ),
    }
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        input_token_details=InputTokenDetails(
            **{k: v for k, v in input_token_details.items() if v is not None}
        ),
        output_token_details=OutputTokenDetails(
            **{k: v for k, v in output_token_details.items() if v is not None}
        ),
    )

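# Illustrative sketch (editor's addition, not part of this commit): mapping an
# OpenAI-style usage block onto LangChain's UsageMetadata.
#
#     _create_usage_metadata(
#         {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
#     )
#     # -> UsageMetadata(input_tokens=10, output_tokens=5, total_tokens=15, ...)
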
def _create_chat_result(
    response: Union[dict, Any],
    generation_info: Optional[Dict] = None,
) -> ChatResult:
    generations = []

    response_dict = (
        response if isinstance(response, dict) else response.model_dump()
    )
    if response_dict.get("error"):
        raise ValueError(response_dict.get("error"))

    token_usage = response_dict.get("usage")
    for res in response_dict["choices"]:
        message = _convert_dict_to_message(res["message"])
        if token_usage and isinstance(message, AIMessage):
            message.usage_metadata = _create_usage_metadata(token_usage)
        generation_info = generation_info or {}
        generation_info["finish_reason"] = (
            res.get("finish_reason")
            if res.get("finish_reason") is not None
            else generation_info.get("finish_reason")
        )
        if "logprobs" in res:
            generation_info["logprobs"] = res["logprobs"]
        gen = ChatGeneration(message=message, generation_info=generation_info)
        generations.append(gen)
    llm_output = {
        "token_usage": token_usage,
        "model_name": response_dict.get("model"),
        "system_fingerprint": response_dict.get("system_fingerprint", ""),
    }

    if getattr(response, "choices", None):
        message = response.choices[0].message  # type: ignore[attr-defined]
        if hasattr(message, "parsed"):
            generations[0].message.additional_kwargs["parsed"] = message.parsed
        if hasattr(message, "refusal"):
            generations[0].message.additional_kwargs["refusal"] = message.refusal

    return ChatResult(generations=generations, llm_output=llm_output)

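# Illustrative sketch (editor's addition, not part of this commit): a minimal
# OpenAI-style completion body and the ChatResult it produces.
#
#     _create_chat_result({
#         "model": "gpt-4o",
#         "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
#         "choices": [{"message": {"role": "assistant", "content": "hi"},
#                      "finish_reason": "stop"}],
#     })
#     # -> ChatResult with one ChatGeneration; message.content == "hi" and
#     #    generation_info["finish_reason"] == "stop"
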
class ChatAbso(BaseChatModel):
    """A smart LLM proxy that automatically routes requests between a fast and
    a slow model based on prompt complexity.

    It uses several heuristics to determine the complexity of the prompt and
    routes the request to the appropriate model. The models used can be
    specified by the user. It only supports OpenAI models for now; other
    providers will be supported in the future. You need an OpenAI API key to
    use this model.

    .. dropdown:: Setup
        :open:

        Set the environment variable ``OPENAI_API_KEY``.

        .. code-block:: bash

            pip install -U langchain-openai
            export OPENAI_API_KEY="your-api-key"

    .. dropdown:: Usage

        Example:

        .. code-block:: python

            abso = ChatAbso(fast_model="gpt-4o", slow_model="o3-mini")
            result = abso.invoke([HumanMessage(content="hello")])
            result = abso.batch([[HumanMessage(content="What is the meaning of life?")]])
    """

    fast_model: str
    """The identifier of the fast model, used for simple or lower-complexity tasks, ensuring quick response times."""
    slow_model: str
    """The identifier of the slow model, used for complex or high-accuracy tasks where thorough processing is needed."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Route the prompt to the fast or slow model and return the result.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of stop tokens that should be respected and, if
                triggered, must be included as part of the final output.
        """
        payload = {
            "messages": [_convert_message_to_dict(message) for message in messages],
            "fastModel": self.fast_model,
            "slowModel": self.slow_model,
            "stream": False,
        }
        # Pass stop tokens to the backend so that it stops generation appropriately.
        if stop is not None:
            payload["stop"] = stop

        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
            "Content-Type": "application/json",
        }

        response = requests.post(
            "http://localhost:8787/v1/chat/completions",
            json=payload,
            headers=headers,
        )

        generation_info = {"headers": dict(response.headers)}
        result = _create_chat_result(response.json(), generation_info)

        # If generation stopped on a stop token, append that token to the
        # output, since the docstring promises it is part of the final text.
        if stop is not None and result.generations:
            stop_token = stop[0] if isinstance(stop, list) else stop
            for generation in result.generations:
                if generation.generation_info.get("finish_reason") == "stop":
                    content = generation.message.content or ""
                    if not content.endswith(stop_token):
                        generation.message.content = content + stop_token

        return result

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "openai-chat"
