Skip to content

Commit d42e2c7

Browse files
committed
Added Anthropic and Groq support
Signed-off-by: devjpt23 <[email protected]>
1 parent f5814b5 commit d42e2c7

File tree

4 files changed

+35
-1
lines changed

4 files changed

+35
-1
lines changed

kai/kai_config.py

+2
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,13 @@ class KaiConfigIncidentStore(BaseModel):
 
 
 class SupportedModelProviders(StrEnum):
+    CHAT_ANTHROPIC = "ChatAnthropic"
     CHAT_OLLAMA = "ChatOllama"
     CHAT_OPENAI = "ChatOpenAI"
     CHAT_BEDROCK = "ChatBedrock"
     FAKE_LIST_CHAT_MODEL = "FakeListChatModel"
     CHAT_GOOGLE_GENERATIVE_AI = "ChatGoogleGenerativeAI"
+    CHAT_GROQ = "ChatGroq"  # trunk-ignore(cspell)
     AZURE_CHAT_OPENAI = "AzureChatOpenAI"
     CHAT_DEEP_SEEK = "ChatDeepSeek"

kai/llm_interfacing/model_provider.py

+30
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
 import os
 from typing import Any, Optional
 
+from langchain_anthropic import ChatAnthropic
 from langchain_aws import ChatBedrock
 from langchain_community.chat_models.fake import FakeListChatModel
 from langchain_core.language_models.base import LanguageModelInput
@@ -11,6 +12,7 @@
 from langchain_core.runnables import ConfigurableField, RunnableConfig
 from langchain_deepseek import ChatDeepSeek
 from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_groq import ChatGroq  # trunk-ignore(cspell)
 from langchain_ollama import ChatOllama
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
 from opentelemetry import trace
@@ -173,6 +175,30 @@ def _get_request_payload(
                 model_args = deep_update(defaults, config.args)
                 model_id = model_args["model"]
 
+            case "ChatAnthropic":
+                model_class = ChatAnthropic
+
+                defaults = {
+                    "model": "claude-3-5-sonnet-20241022",
+                    "temperature": 0,
+                    "timeout": None,
+                    "max_retries": 2,
+                }
+
+            case "ChatGroq":  # trunk-ignore(cspell)
+                model_class = ChatGroq  # trunk-ignore(cspell)
+
+                defaults = {
+                    "model": "mixtral-8x7b-32768",
+                    "temperature": 0,
+                    "timeout": None,
+                    "max_retries": 2,
+                    "max_tokens": 2049,
+                }
+
+                model_args = deep_update(defaults, config.args)
+                model_id = model_args["model"]
+
             case _:
                 raise Exception(f"Unrecognized provider '{config.provider}'")

@@ -212,6 +238,10 @@ def challenge(k: str) -> BaseMessage:
             challenge("max_tokens")
         elif isinstance(self.llm, ChatDeepSeek):
             challenge("max_tokens")
+        elif isinstance(self.llm, ChatAnthropic):
+            challenge("max_tokens")
+        elif isinstance(self.llm, ChatGroq):  # trunk-ignore(cspell)
+            challenge("max_tokens")
 
     @tracer.start_as_current_span("invoke_llm")
     def invoke(

kai/rpc_server/server.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ class GetCodeplanAgentSolutionParams(BaseModel):
     max_depth: Optional[int] = None
     max_priority: Optional[int] = None
 
-    chat_token: str
+    chat_token: Optional[str] = None
 
 
 class GetCodeplanAgentSolutionResult(BaseModel):

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ dependencies = [
     "python-dateutil==2.8.2",
     "Jinja2==3.1.4",
     "langchain==0.3.17",
+    "langchain-anthropic==0.3.7",
     "langchain-community==0.3.1",
     "langchain-openai==0.3.3",
     "langchain-ollama==0.2.3",
     "langchain-google-genai==2.0.9",
+    "langchain-groq==0.2.4",
     "langchain-aws==0.2.11",
     "langchain-experimental==0.3.2",
     "langchain-deepseek-official==0.1.0",

0 commit comments

Comments
 (0)