from typing import Any, Dict, List, Mapping, Optional, Union, cast
import json
import os

import requests

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    InvalidToolCall,
    SystemMessage,
    ToolCall,
    ToolMessage,
)
from langchain_core.messages.ai import (
    InputTokenDetails,
    OutputTokenDetails,
    UsageMetadata,
)
from langchain_core.output_parsers.openai_tools import (
    make_invalid_tool_call,
    parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatResult


def _format_message_content(content: Any) -> Any:
    """Format message content, dropping unsupported block types."""
    if content and isinstance(content, list):
        # Remove unexpected block types; everything else passes through.
        formatted_content = [
            block
            for block in content
            if not (isinstance(block, dict) and block.get("type") == "tool_use")
        ]
    else:
        formatted_content = content

    return formatted_content

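# A minimal sketch of the filtering above ("tool_use" refers to Anthropic-style
# content blocks; non-list content is returned untouched):
#
#   _format_message_content(
#       [{"type": "text", "text": "hi"}, {"type": "tool_use", "id": "t1"}]
#   )
#   # -> [{"type": "text", "text": "hi"}]
#
#   _format_message_content("plain string")
#   # -> "plain string"
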
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to an OpenAI-style message dictionary.

    Args:
        message: The LangChain message.

    Returns:
        The dictionary.
    """
    message_dict: Dict[str, Any] = {
        "content": _format_message_content(message.content)
    }
    if (name := message.name or message.additional_kwargs.get("name")) is not None:
        message_dict["name"] = name

    # Populate role and additional message data.
    if isinstance(message, ChatMessage):
        message_dict["role"] = message.role
    elif isinstance(message, HumanMessage):
        message_dict["role"] = "user"
    elif isinstance(message, AIMessage):
        message_dict["role"] = "assistant"
        if "function_call" in message.additional_kwargs:
            message_dict["function_call"] = message.additional_kwargs["function_call"]
        if message.tool_calls or message.invalid_tool_calls:
            message_dict["tool_calls"] = [
                _lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
            ] + [
                _lc_invalid_tool_call_to_openai_tool_call(tc)
                for tc in message.invalid_tool_calls
            ]
        elif "tool_calls" in message.additional_kwargs:
            message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
            tool_call_supported_props = {"id", "type", "function"}
            message_dict["tool_calls"] = [
                {k: v for k, v in tool_call.items() if k in tool_call_supported_props}
                for tool_call in message_dict["tool_calls"]
            ]
        # If tool calls are present, a null content value should be None,
        # not an empty string.
        if "function_call" in message_dict or "tool_calls" in message_dict:
            message_dict["content"] = message_dict["content"] or None

        if "audio" in message.additional_kwargs:
            raw_audio = message.additional_kwargs["audio"]
            audio = {"id": raw_audio["id"]} if "id" in raw_audio else raw_audio
            message_dict["audio"] = audio
    elif isinstance(message, SystemMessage):
        message_dict["role"] = message.additional_kwargs.get(
            "__openai_role__", "system"
        )
    elif isinstance(message, FunctionMessage):
        message_dict["role"] = "function"
    elif isinstance(message, ToolMessage):
        message_dict["role"] = "tool"
        message_dict["tool_call_id"] = message.tool_call_id

        supported_props = {"content", "role", "tool_call_id"}
        message_dict = {k: v for k, v in message_dict.items() if k in supported_props}
    else:
        raise TypeError(f"Got unknown type {message}")
    return message_dict

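# Illustrative round trips (a sketch, not exhaustive):
#
#   _convert_message_to_dict(HumanMessage(content="hello"))
#   # -> {"content": "hello", "role": "user"}
#
#   _convert_message_to_dict(ToolMessage(content="42", tool_call_id="call-1"))
#   # -> {"content": "42", "role": "tool", "tool_call_id": "call-1"}
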
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Convert a dictionary to a LangChain message.

    Args:
        _dict: The dictionary.

    Returns:
        The LangChain message.
    """
    role = _dict.get("role")
    name = _dict.get("name")
    id_ = _dict.get("id")
    if role == "user":
        return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
    elif role == "assistant":
        # Fix for Azure; OpenAI also returns None content for tool invocations.
        content = _dict.get("content", "") or ""
        additional_kwargs: Dict = {}
        if function_call := _dict.get("function_call"):
            additional_kwargs["function_call"] = dict(function_call)
        tool_calls = []
        invalid_tool_calls = []
        if raw_tool_calls := _dict.get("tool_calls"):
            additional_kwargs["tool_calls"] = raw_tool_calls
            for raw_tool_call in raw_tool_calls:
                try:
                    tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
                except Exception as e:
                    invalid_tool_calls.append(
                        make_invalid_tool_call(raw_tool_call, str(e))
                    )
        if audio := _dict.get("audio"):
            additional_kwargs["audio"] = audio
        return AIMessage(
            content=content,
            additional_kwargs=additional_kwargs,
            name=name,
            id=id_,
            tool_calls=tool_calls,
            invalid_tool_calls=invalid_tool_calls,
        )
    elif role in ("system", "developer"):
        additional_kwargs = {"__openai_role__": role} if role == "developer" else {}
        return SystemMessage(
            content=_dict.get("content", ""),
            name=name,
            id=id_,
            additional_kwargs=additional_kwargs,
        )
    elif role == "function":
        return FunctionMessage(
            content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
        )
    elif role == "tool":
        additional_kwargs = {}
        if "name" in _dict:
            additional_kwargs["name"] = _dict["name"]
        return ToolMessage(
            content=_dict.get("content", ""),
            tool_call_id=cast(str, _dict.get("tool_call_id")),
            additional_kwargs=additional_kwargs,
            name=name,
            id=id_,
        )
    else:
        return ChatMessage(content=_dict.get("content", ""), role=role, id=id_)  # type: ignore[arg-type]

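# For example, an assistant choice with a tool call parses into an AIMessage
# whose ``tool_calls`` hold the decoded arguments (a sketch with made-up ids):
#
#   _convert_dict_to_message({
#       "role": "assistant",
#       "content": None,
#       "tool_calls": [{
#           "id": "call-1",
#           "type": "function",
#           "function": {"name": "add", "arguments": '{"a": 1, "b": 2}'},
#       }],
#   })
#   # -> AIMessage(content="", tool_calls=[{"name": "add",
#   #              "args": {"a": 1, "b": 2}, "id": "call-1", ...}], ...)
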
def _lc_invalid_tool_call_to_openai_tool_call(
    invalid_tool_call: InvalidToolCall,
) -> dict:
    # Invalid tool calls keep their raw (unparsed) argument string as-is.
    return {
        "type": "function",
        "id": invalid_tool_call["id"],
        "function": {
            "name": invalid_tool_call["name"],
            "arguments": invalid_tool_call["args"],
        },
    }


def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
    # Valid tool calls carry parsed args, which must be re-serialized to JSON.
    return {
        "type": "function",
        "id": tool_call["id"],
        "function": {
            "name": tool_call["name"],
            "arguments": json.dumps(tool_call["args"]),
        },
    }

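# The two converters mirror each other; for a valid call the parsed args come
# back out as a JSON string (a sketch with a made-up id):
#
#   _lc_tool_call_to_openai_tool_call(
#       {"name": "add", "args": {"a": 1}, "id": "call-1"}
#   )
#   # -> {"type": "function", "id": "call-1",
#   #     "function": {"name": "add", "arguments": '{"a": 1}'}}
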
def _create_usage_metadata(oai_token_usage: dict) -> UsageMetadata:
    """Build LangChain ``UsageMetadata`` from an OpenAI ``usage`` payload."""
    input_tokens = oai_token_usage.get("prompt_tokens", 0)
    output_tokens = oai_token_usage.get("completion_tokens", 0)
    total_tokens = oai_token_usage.get("total_tokens", input_tokens + output_tokens)
    input_token_details: dict = {
        "audio": (oai_token_usage.get("prompt_tokens_details") or {}).get(
            "audio_tokens"
        ),
        "cache_read": (oai_token_usage.get("prompt_tokens_details") or {}).get(
            "cached_tokens"
        ),
    }
    output_token_details: dict = {
        "audio": (oai_token_usage.get("completion_tokens_details") or {}).get(
            "audio_tokens"
        ),
        "reasoning": (oai_token_usage.get("completion_tokens_details") or {}).get(
            "reasoning_tokens"
        ),
    }
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        input_token_details=InputTokenDetails(
            **{k: v for k, v in input_token_details.items() if v is not None}
        ),
        output_token_details=OutputTokenDetails(
            **{k: v for k, v in output_token_details.items() if v is not None}
        ),
    )

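# For example, given a typical OpenAI usage payload (values made up):
#
#   _create_usage_metadata({
#       "prompt_tokens": 10,
#       "completion_tokens": 5,
#       "total_tokens": 15,
#       "completion_tokens_details": {"reasoning_tokens": 3},
#   })
#   # -> UsageMetadata(input_tokens=10, output_tokens=5, total_tokens=15,
#   #                  output_token_details={"reasoning": 3}, ...)
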
def _create_chat_result(
    response: Union[dict, Any],
    generation_info: Optional[Dict] = None,
) -> ChatResult:
    """Convert an OpenAI-style completion response into a ``ChatResult``."""
    generations = []

    response_dict = response if isinstance(response, dict) else response.model_dump()
    if response_dict.get("error"):
        raise ValueError(response_dict.get("error"))

    token_usage = response_dict.get("usage")
    for res in response_dict["choices"]:
        message = _convert_dict_to_message(res["message"])
        if token_usage and isinstance(message, AIMessage):
            message.usage_metadata = _create_usage_metadata(token_usage)
        generation_info = generation_info or {}
        generation_info["finish_reason"] = (
            res.get("finish_reason")
            if res.get("finish_reason") is not None
            else generation_info.get("finish_reason")
        )
        if "logprobs" in res:
            generation_info["logprobs"] = res["logprobs"]
        gen = ChatGeneration(message=message, generation_info=generation_info)
        generations.append(gen)
    llm_output = {
        "token_usage": token_usage,
        "model_name": response_dict.get("model"),
        "system_fingerprint": response_dict.get("system_fingerprint", ""),
    }

    # If the response is a pydantic model (e.g. from the OpenAI SDK), surface
    # structured-output fields that ``model_dump`` may not round-trip.
    if getattr(response, "choices", None):
        message = response.choices[0].message  # type: ignore[attr-defined]
        if hasattr(message, "parsed"):
            generations[0].message.additional_kwargs["parsed"] = message.parsed
        if hasattr(message, "refusal"):
            generations[0].message.additional_kwargs["refusal"] = message.refusal

    return ChatResult(generations=generations, llm_output=llm_output)

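# A minimal response dict that exercises the happy path (field values made up):
#
#   _create_chat_result({
#       "model": "gpt-4o",
#       "choices": [{
#           "message": {"role": "assistant", "content": "hi"},
#           "finish_reason": "stop",
#       }],
#       "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
#   })
#   # -> ChatResult with one ChatGeneration wrapping AIMessage(content="hi"),
#   #    with llm_output carrying token_usage and model_name.
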
class ChatAbso(BaseChatModel):
    """A smart LLM proxy that automatically routes requests between fast and
    slow models based on prompt complexity.

    It uses several heuristics to determine the complexity of the prompt and
    routes the request to the appropriate model. The models used can be
    specified by the user. It only supports OpenAI models for now; other
    providers will be supported in the future. You need an OpenAI API key to
    use this model.

    .. dropdown:: Setup
        :open:

        Set the environment variable ``OPENAI_API_KEY``.

        .. code-block:: bash

            pip install -U langchain-openai
            export OPENAI_API_KEY="your-api-key"

    .. dropdown:: Usage
        Example:

        .. code-block:: python

            abso = ChatAbso(fast_model="gpt-4o", slow_model="o3-mini")
            result = abso.invoke([HumanMessage(content="hello")])
            results = abso.batch([[HumanMessage(content="What is the meaning of life?")]])
    """

    fast_model: str
    """The identifier of the fast model, used for simple or lower-complexity
    tasks where quick response times matter."""
    slow_model: str
    """The identifier of the slow model, used for complex or high-accuracy
    tasks where thorough processing is needed."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call the Abso proxy and return the completion as a ``ChatResult``.

        Args:
            messages: The prompt, composed of a list of messages.
            stop: A list of stop tokens that should be respected and, if
                triggered, must be included as part of the final output.
            run_manager: Callback manager for the LLM run.
        """
        payload = {
            "messages": [_convert_message_to_dict(message) for message in messages],
            "fastModel": self.fast_model,
            "slowModel": self.slow_model,
            "stream": False,
        }
        # Pass stop tokens to the backend so that it stops generation appropriately.
        if stop is not None:
            payload["stop"] = stop

        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
            "Content-Type": "application/json",
        }

        response = requests.post(
            "http://localhost:8787/v1/chat/completions",
            json=payload,
            headers=headers,
        )

        generation_info = {"headers": dict(response.headers)}
        result = _create_chat_result(response.json(), generation_info)

        # Per the contract above, a triggered stop token must appear in the
        # final output, so re-append it if the backend stripped it.
        if stop is not None and result.generations:
            stop_token = stop[0]
            for generation in result.generations:
                if generation.generation_info.get("finish_reason") == "stop":
                    content = generation.message.content or ""
                    if isinstance(content, str) and not content.endswith(stop_token):
                        generation.message.content = content + stop_token

        return result

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "openai-chat"
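

# Example usage (a sketch: assumes OPENAI_API_KEY is set and an Abso proxy is
# listening on the hardcoded http://localhost:8787 endpoint above):
#
#   abso = ChatAbso(fast_model="gpt-4o", slow_model="o3-mini")
#   reply = abso.invoke([HumanMessage(content="hello")])
#   print(reply.content)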