Skip to content

feat: add xAI model provider #10272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
63 changes: 63 additions & 0 deletions api/core/model_runtime/model_providers/x/llm/grok-beta.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
model: grok-beta
label:
en_US: Grok beta
model_type: llm
features:
- multi-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 2.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

- name: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"

- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: 0
max: 2.0
precision: 1
required: false
help:
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"

- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"
37 changes: 37 additions & 0 deletions api/core/model_runtime/model_providers/x/llm/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from collections.abc import Generator
from typing import Optional, Union

from yarl import URL

from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)

def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)

@staticmethod
def _add_custom_parameters(credentials) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1"
credentials["mode"] = LLMMode.CHAT.value
credentials["function_calling_type"] = "tool_call"
25 changes: 25 additions & 0 deletions api/core/model_runtime/model_providers/x/x.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class XAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception

:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(model="grok-beta", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex
38 changes: 38 additions & 0 deletions api/core/model_runtime/model_providers/x/x.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
provider: x
label:
en_US: xAI
description:
en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
icon_small:
en_US: x-ai-logo.svg
icon_large:
en_US: x-ai-logo.svg
help:
title:
en_US: Get your token from xAI
zh_Hans: 从 xAI 获取 token
url:
en_US: https://x.ai/api
supported_model_types:
- llm
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: endpoint_url
label:
en_US: API Base
type: text-input
required: false
default: https://api.x.ai/v1
placeholder:
zh_Hans: 在此输入您的 API Base
en_US: Enter your API Base
4 changes: 4 additions & 0 deletions api/tests/integration_tests/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,7 @@ GPUSTACK_API_KEY=

# Gitee AI Credentials
GITEE_AI_API_KEY=

# xAI Credentials
XAI_API_KEY=
XAI_API_BASE=
Empty file.
204 changes: 204 additions & 0 deletions api/tests/integration_tests/model_runtime/x/test_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


def test_predefined_models():
model = XAILargeLanguageModel()
model_schemas = model.predefined_models()

assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity)


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
model = XAILargeLanguageModel()

with pytest.raises(CredentialsValidateFailedError):
# model name to gpt-3.5-turbo because of mocking
model.validate_credentials(
model="gpt-3.5-turbo",
credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"},
)

model.validate_credentials(
model="grok-beta",
credentials={
"api_key": os.environ.get("XAI_API_KEY"),
"endpoint_url": os.environ.get("XAI_API_BASE"),
"mode": "chat",
},
)


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model(setup_openai_mock):
model = XAILargeLanguageModel()

result = model.invoke(
model="grok-beta",
credentials={
"api_key": os.environ.get("XAI_API_KEY"),
"endpoint_url": os.environ.get("XAI_API_BASE"),
"mode": "chat",
},
prompt_messages=[
SystemPromptMessage(
content="You are a helpful AI assistant.",
),
UserPromptMessage(content="Hello World!"),
],
model_parameters={
"temperature": 0.0,
"top_p": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"max_tokens": 10,
},
stop=["How"],
stream=False,
user="foo",
)

assert isinstance(result, LLMResult)
assert len(result.message.content) > 0


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model_with_tools(setup_openai_mock):
model = XAILargeLanguageModel()

result = model.invoke(
model="grok-beta",
credentials={
"api_key": os.environ.get("XAI_API_KEY"),
"endpoint_url": os.environ.get("XAI_API_BASE"),
"mode": "chat",
},
prompt_messages=[
SystemPromptMessage(
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content="what's the weather today in London?",
),
],
model_parameters={"temperature": 0.0, "max_tokens": 100},
tools=[
PromptMessageTool(
name="get_weather",
description="Determine weather in my location",
parameters={
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": ["location"],
},
),
PromptMessageTool(
name="get_stock_price",
description="Get the current stock price",
parameters={
"type": "object",
"properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
"required": ["symbol"],
},
),
],
stream=False,
user="foo",
)

assert isinstance(result, LLMResult)
assert isinstance(result.message, AssistantPromptMessage)


@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock):
model = XAILargeLanguageModel()

result = model.invoke(
model="grok-beta",
credentials={
"api_key": os.environ.get("XAI_API_KEY"),
"endpoint_url": os.environ.get("XAI_API_BASE"),
"mode": "chat",
},
prompt_messages=[
SystemPromptMessage(
content="You are a helpful AI assistant.",
),
UserPromptMessage(content="Hello World!"),
],
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=True,
user="foo",
)

assert isinstance(result, Generator)

for chunk in result:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
if chunk.delta.finish_reason is not None:
assert chunk.delta.usage is not None
assert chunk.delta.usage.completion_tokens > 0


def test_get_num_tokens():
model = XAILargeLanguageModel()

num_tokens = model.get_num_tokens(
model="grok-beta",
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
)

assert num_tokens == 10

num_tokens = model.get_num_tokens(
model="grok-beta",
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
prompt_messages=[
SystemPromptMessage(
content="You are a helpful AI assistant.",
),
UserPromptMessage(content="Hello World!"),
],
tools=[
PromptMessageTool(
name="get_weather",
description="Determine weather in my location",
parameters={
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": ["location"],
},
),
],
)

assert num_tokens == 77