
Commit b7e5fc1

hjlarry authored and jiangzhijie committed
feat: add xAI model provider (langgenius#10272)
1 parent 15df8ec commit b7e5fc1

File tree

10 files changed, +372 -0 lines


api/core/model_runtime/model_providers/x/__init__.py

Whitespace-only changes.

api/core/model_runtime/model_providers/x/llm/__init__.py

Whitespace-only changes.

api/core/model_runtime/model_providers/x/llm/grok-beta.yaml

@@ -0,0 +1,63 @@
model: grok-beta
label:
  en_US: Grok beta
model_type: llm
features:
  - multi-tool-call
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    label:
      en_US: "Temperature"
      zh_Hans: "采样温度"
    type: float
    default: 0.7
    min: 0.0
    max: 2.0
    precision: 1
    required: true
    help:
      en_US: "Sampling temperature controls the randomness of the output. The temperature value is within the range of [0.0, 2.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature according to your needs, and to avoid adjusting both at the same time."
      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: top_p
    label:
      en_US: "Top P"
      zh_Hans: "Top P"
    type: float
    default: 0.7
    min: 0.0
    max: 1.0
    precision: 1
    required: true
    help:
      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature according to your needs, and to avoid adjusting both at the same time."
      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

  - name: frequency_penalty
    use_template: frequency_penalty
    label:
      en_US: "Frequency Penalty"
      zh_Hans: "频率惩罚"
    type: float
    default: 0
    min: 0
    max: 2.0
    precision: 1
    required: false
    help:
      en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
      zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"

  - name: user
    use_template: text
    label:
      en_US: "User"
      zh_Hans: "用户"
    type: string
    required: false
    help:
      en_US: "Used to track and differentiate conversation requests from different users."
      zh_Hans: "用于追踪和区分不同用户的对话请求。"

api/core/model_runtime/model_providers/x/llm/llm.py

@@ -0,0 +1,37 @@
from collections.abc import Generator
from typing import Optional, Union

from yarl import URL

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: Optional[list[PromptMessageTool]] = None,
        stop: Optional[list[str]] = None,
        stream: bool = True,
        user: Optional[str] = None,
    ) -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    @staticmethod
    def _add_custom_parameters(credentials) -> None:
        credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1"
        credentials["mode"] = LLMMode.CHAT.value
        credentials["function_calling_type"] = "tool_call"

api/core/model_runtime/model_providers/x/x.py

@@ -0,0 +1,25 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class XAIProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials.
        If validation fails, raise an exception.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)
            model_instance.validate_credentials(model="grok-beta", credentials=credentials)
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
            raise ex

api/core/model_runtime/model_providers/x/x.yaml

@@ -0,0 +1,38 @@
provider: x
label:
  en_US: xAI
description:
  en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
icon_small:
  en_US: x-ai-logo.svg
icon_large:
  en_US: x-ai-logo.svg
help:
  title:
    en_US: Get your token from xAI
    zh_Hans: 从 xAI 获取 token
  url:
    en_US: https://x.ai/api
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
    - variable: endpoint_url
      label:
        en_US: API Base
      type: text-input
      required: false
      default: https://api.x.ai/v1
      placeholder:
        zh_Hans: 在此输入您的 API Base
        en_US: Enter your API Base
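
For reference, the credential form above yields a two-key dict; this is what XAIProvider.validate_provider_credentials receives, and what _add_custom_parameters then augments before any request is made. A sketch of the shape only, with placeholder values:

    # Placeholder values; keys follow credential_form_schemas in x.yaml above.
    credentials = {
        "api_key": "xai-...",                   # secret-input, required
        "endpoint_url": "https://api.x.ai/v1",  # text-input, optional (default in x.yaml)
    }

    # After XAILargeLanguageModel._add_custom_parameters(credentials), the dict
    # also carries mode="chat" and function_calling_type="tool_call".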

api/tests/integration_tests/.env.example

+4 lines

@@ -95,3 +95,7 @@ GPUSTACK_API_KEY=
 
 # Gitee AI Credentials
 GITEE_AI_API_KEY=
+
+# xAI Credentials
+XAI_API_KEY=
+XAI_API_BASE=

api/tests/integration_tests/model_runtime/x/__init__.py

Whitespace-only changes.

api/tests/integration_tests/model_runtime/x/test_llm.py

@@ -0,0 +1,204 @@
import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


def test_predefined_models():
    model = XAILargeLanguageModel()
    model_schemas = model.predefined_models()

    assert len(model_schemas) >= 1
    assert isinstance(model_schemas[0], AIModelEntity)


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
    model = XAILargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        # model name set to gpt-3.5-turbo because of mocking
        model.validate_credentials(
            model="gpt-3.5-turbo",
            credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"},
        )

    model.validate_credentials(
        model="grok-beta",
        credentials={
            "api_key": os.environ.get("XAI_API_KEY"),
            "endpoint_url": os.environ.get("XAI_API_BASE"),
            "mode": "chat",
        },
    )


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model(setup_openai_mock):
    model = XAILargeLanguageModel()

    result = model.invoke(
        model="grok-beta",
        credentials={
            "api_key": os.environ.get("XAI_API_KEY"),
            "endpoint_url": os.environ.get("XAI_API_BASE"),
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            "temperature": 0.0,
            "top_p": 1.0,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
            "max_tokens": 10,
        },
        stop=["How"],
        stream=False,
        user="foo",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model_with_tools(setup_openai_mock):
    model = XAILargeLanguageModel()

    result = model.invoke(
        model="grok-beta",
        credentials={
            "api_key": os.environ.get("XAI_API_KEY"),
            "endpoint_url": os.environ.get("XAI_API_BASE"),
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content="what's the weather today in London?",
            ),
        ],
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        tools=[
            PromptMessageTool(
                name="get_weather",
                description="Determine weather in my location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            ),
            PromptMessageTool(
                name="get_stock_price",
                description="Get the current stock price",
                parameters={
                    "type": "object",
                    "properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
                    "required": ["symbol"],
                },
            ),
        ],
        stream=False,
        user="foo",
    )

    assert isinstance(result, LLMResult)
    assert isinstance(result.message, AssistantPromptMessage)


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock):
    model = XAILargeLanguageModel()

    result = model.invoke(
        model="grok-beta",
        credentials={
            "api_key": os.environ.get("XAI_API_KEY"),
            "endpoint_url": os.environ.get("XAI_API_BASE"),
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=True,
        user="foo",
    )

    assert isinstance(result, Generator)

    for chunk in result:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
        if chunk.delta.finish_reason is not None:
            assert chunk.delta.usage is not None
            assert chunk.delta.usage.completion_tokens > 0


def test_get_num_tokens():
    model = XAILargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model="grok-beta",
        credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert num_tokens == 10

    num_tokens = model.get_num_tokens(
        model="grok-beta",
        credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[
            PromptMessageTool(
                name="get_weather",
                description="Determine weather in my location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": ["location"],
                },
            ),
        ],
    )

    assert num_tokens == 77
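
The exact counts asserted above (10 and 77) come from the local tokenizer in the OpenAI-compatible base class plus fixed per-message chat overhead, not from the xAI API. As a rough, standalone illustration only (the GPT-2 encoding is an assumption about the base class, and raw content tokens exclude that overhead):

    import tiktoken  # illustration only; assumes the package is installed

    enc = tiktoken.get_encoding("gpt2")
    # Counts raw content tokens; chat formatting adds per-message overhead on top.
    print(len(enc.encode("Hello World!")))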
