@@ -3,6 +3,7 @@
 import os
 from typing import Any, Optional

+from langchain_anthropic import ChatAnthropic
 from langchain_aws import ChatBedrock
 from langchain_community.chat_models.fake import FakeListChatModel
 from langchain_core.language_models.base import LanguageModelInput
@@ -11,6 +12,7 @@
 from langchain_core.runnables import ConfigurableField, RunnableConfig
 from langchain_deepseek import ChatDeepSeek
 from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_groq import ChatGroq  # trunk-ignore(cspell)
 from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 from opentelemetry import trace
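Both added imports assume their distribution packages are installed (langchain-anthropic and langchain-groq on PyPI). Every class imported here implements LangChain's common chat-model interface, which is what lets the provider dispatch further down swap them freely. A minimal, credential-free illustration using the already-imported FakeListChatModel:

    from langchain_community.chat_models.fake import FakeListChatModel

    # FakeListChatModel replays canned responses in order, so provider-agnostic
    # code can be exercised without network access or API keys.
    llm = FakeListChatModel(responses=["hello"])
    print(llm.invoke("hi").content)  # -> "hello"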
@@ -173,6 +175,33 @@ def _get_request_payload(
                 model_args = deep_update(defaults, config.args)
                 model_id = model_args["model"]

+            case "ChatAnthropic":
+                model_class = ChatAnthropic
+
+                defaults = {
+                    "model": "claude-3-5-sonnet-20241022",
+                    "temperature": 0,
+                    "timeout": None,
+                    "max_retries": 2,
+                }
+
+                model_args = deep_update(defaults, config.args)
+                model_id = model_args["model"]
+
+            case "ChatGroq":  # trunk-ignore(cspell)
+                model_class = ChatGroq  # trunk-ignore(cspell)
+
+                defaults = {
+                    "model": "mixtral-8x7b-32768",
+                    "temperature": 0,
+                    "timeout": None,
+                    "max_retries": 2,
+                    "max_tokens": 2049,
+                }
+
+                model_args = deep_update(defaults, config.args)
+                model_id = model_args["model"]
+
             case _:
                 raise Exception(f"Unrecognized provider '{config.provider}'")
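Each case only declares defaults; deep_update then lets config.args override them field by field. The helper itself is outside this hunk, so the sketch below assumes the usual recursive-merge semantics (nested dicts merged, the caller's values winning), which is how pydantic's deep_update behaves:

    from typing import Any

    def deep_update(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
        # Assumed semantics: recursive merge where values from `override` win.
        merged = dict(base)
        for key, value in override.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = deep_update(merged[key], value)
            else:
                merged[key] = value
        return merged

    defaults = {
        "model": "mixtral-8x7b-32768",
        "temperature": 0,
        "timeout": None,
        "max_retries": 2,
        "max_tokens": 2049,
    }

    # A caller only overrides what differs from the defaults.
    print(deep_update(defaults, {"model": "llama-3.1-8b-instant", "temperature": 0.7}))
    # {'model': 'llama-3.1-8b-instant', 'temperature': 0.7, 'timeout': None,
    #  'max_retries': 2, 'max_tokens': 2049}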
@@ -212,6 +241,10 @@ def challenge(k: str) -> BaseMessage:
             challenge("max_tokens")
         elif isinstance(self.llm, ChatDeepSeek):
             challenge("max_tokens")
+        elif isinstance(self.llm, ChatAnthropic):
+            challenge("max_tokens")
+        elif isinstance(self.llm, ChatGroq):  # trunk-ignore(cspell)
+            challenge("max_tokens")

     @tracer.start_as_current_span("invoke_llm")
     def invoke(
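The second hunk extends the max_tokens challenge to the two new providers; both accept max_tokens as a constructor argument. End to end, a config selecting either provider resolves to roughly the following constructor calls. This is a sketch built from the diff's defaults, not the file's actual wiring; both clients read their API keys from the environment (ANTHROPIC_API_KEY and GROQ_API_KEY respectively):

    from langchain_anthropic import ChatAnthropic
    from langchain_groq import ChatGroq

    # Defaults from the ChatAnthropic case above.
    anthropic_llm = ChatAnthropic(
        model="claude-3-5-sonnet-20241022",
        temperature=0,
        timeout=None,
        max_retries=2,
    )

    # Defaults from the ChatGroq case above.
    groq_llm = ChatGroq(
        model="mixtral-8x7b-32768",
        temperature=0,
        timeout=None,
        max_retries=2,
        max_tokens=2049,
    )

    print(anthropic_llm.invoke("Say hi in one word.").content)
    print(groq_llm.invoke("Say hi in one word.").content)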