
Commit 6e03c10

gitlawr authored and iamjoel committed
feat: add gpustack model provider (#10158)
1 parent 8f14c42 commit 6e03c10

17 files changed: +705 −1 lines changed
@@ -0,0 +1,10 @@
import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class GPUStackProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        pass
@@ -0,0 +1,120 @@
provider: gpustack
label:
  en_US: GPUStack
icon_small:
  en_US: icon_s_en.png
icon_large:
  en_US: icon_l_en.png
supported_model_types:
  - llm
  - text-embedding
  - rerank
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: endpoint_url
      label:
        zh_Hans: 服务器地址
        en_US: Server URL
      type: text-input
      required: true
      placeholder:
        zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100
        en_US: Enter the GPUStack server URL, e.g. http://192.168.1.100
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 输入您的 API Key
        en_US: Enter your API Key
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        en_US: Completion mode
      type: select
      required: false
      default: chat
      placeholder:
        zh_Hans: 选择补全类型
        en_US: Select completion type
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
    - variable: context_size
      label:
        zh_Hans: 模型上下文长度
        en_US: Model context size
      required: true
      type: text-input
      default: "8192"
      placeholder:
        zh_Hans: 输入您的模型上下文长度
        en_US: Enter your Model context size
    - variable: max_tokens_to_sample
      label:
        zh_Hans: 最大 token 上限
        en_US: Upper bound for max tokens
      show_on:
        - variable: __model_type
          value: llm
      default: "8192"
      type: text-input
    - variable: function_calling_type
      show_on:
        - variable: __model_type
          value: llm
      label:
        en_US: Function calling
      type: select
      required: false
      default: no_call
      options:
        - value: function_call
          label:
            en_US: Function Call
            zh_Hans: Function Call
        - value: tool_call
          label:
            en_US: Tool Call
            zh_Hans: Tool Call
        - value: no_call
          label:
            en_US: Not Support
            zh_Hans: 不支持
    - variable: vision_support
      show_on:
        - variable: __model_type
          value: llm
      label:
        zh_Hans: Vision 支持
        en_US: Vision Support
      type: select
      required: false
      default: no_support
      options:
        - value: support
          label:
            en_US: Support
            zh_Hans: 支持
        - value: no_support
          label:
            en_US: Not Support
            zh_Hans: 不支持
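For reference, a credentials dict shaped by this schema for an LLM deployment might look like the sketch below. The concrete values are illustrative placeholders, not defaults taken from the commit (only "chat", "8192", "no_call", and "no_support" come from the schema's defaults).

# Illustrative only: credentials produced by the schema above for an LLM.
credentials = {
    "endpoint_url": "http://192.168.1.100",  # Server URL (required)
    "api_key": "gpustack_xxx",               # API Key (secret-input, required; placeholder value)
    "mode": "chat",                          # Completion mode: chat | completion
    "context_size": "8192",                  # Model context size (stored as a string)
    "max_tokens_to_sample": "8192",          # Upper bound for max tokens
    "function_calling_type": "no_call",      # function_call | tool_call | no_call
    "vision_support": "no_support",          # support | no_support
}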

api/core/model_runtime/model_providers/gpustack/llm/__init__.py

Whitespace-only changes.
@@ -0,0 +1,45 @@
from collections.abc import Generator

from yarl import URL

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
    OAIAPICompatLargeLanguageModel,
)


class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        model_parameters: dict,
        tools: list[PromptMessageTool] | None = None,
        stop: list[str] | None = None,
        stream: bool = True,
        user: str | None = None,
    ) -> LLMResult | Generator:
        return super()._invoke(
            model,
            credentials,
            prompt_messages,
            model_parameters,
            tools,
            stop,
            stream,
            user,
        )

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
        credentials["mode"] = "chat"

api/core/model_runtime/model_providers/gpustack/rerank/__init__.py

Whitespace-only changes.
@@ -0,0 +1,146 @@
from json import dumps
from typing import Optional

import httpx
from requests import post
from yarl import URL

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    FetchFrom,
    ModelPropertyKey,
    ModelType,
)
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class GPUStackRerankModel(RerankModel):
    """
    Model class for GPUStack rerank model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n documents to return
        :param user: unique user id
        :return: rerank result
        """
        if len(docs) == 0:
            return RerankResult(model=model, docs=[])

        endpoint_url = credentials["endpoint_url"]
        headers = {
            "Authorization": f"Bearer {credentials.get('api_key')}",
            "Content-Type": "application/json",
        }

        data = {"model": model, "query": query, "documents": docs, "top_n": top_n}

        try:
            response = post(
                str(URL(endpoint_url) / "v1" / "rerank"),
                headers=headers,
                data=dumps(data),
                timeout=10,
            )
            response.raise_for_status()
            results = response.json()

            rerank_documents = []
            for result in results["results"]:
                index = result["index"]
                if "document" in result:
                    text = result["document"]["text"]
                else:
                    text = docs[index]

                rerank_document = RerankDocument(
                    index=index,
                    text=text,
                    score=result["relevance_score"],
                )

                if score_threshold is None or result["relevance_score"] >= score_threshold:
                    rerank_documents.append(rerank_document)

            return RerankResult(model=model, docs=rerank_documents)
        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        """
        return {
            InvokeConnectionError: [httpx.ConnectError],
            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [httpx.HTTPStatusError],
            InvokeBadRequestError: [httpx.RequestError],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        generate custom model entities from credentials
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.RERANK,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
        )

        return entity
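Unlike the LLM and embedding classes, the rerank model bypasses the OpenAI-compatible base class and posts directly to the server's v1/rerank path with {model, query, documents, top_n}, then reads index, relevance_score, and the optional document.text from each result. A small self-contained sketch of that parsing and threshold filtering follows; the response payload is fabricated for illustration, not captured GPUStack output.

# Sketch (not part of the commit): the response shape _invoke expects and the
# same score-threshold filtering it applies. The payload below is made up.
docs = ["Carson City is the capital of Nevada.", "Saipan is the capital of the CNMI."]
results = {
    "results": [
        {"index": 1, "relevance_score": 0.92, "document": {"text": docs[1]}},
        {"index": 0, "relevance_score": 0.31},  # no "document" key: fall back to docs[index]
    ]
}

score_threshold = 0.8
kept = []
for result in results["results"]:
    text = result["document"]["text"] if "document" in result else docs[result["index"]]
    if result["relevance_score"] >= score_threshold:
        kept.append((result["index"], text, result["relevance_score"]))

print(kept)  # [(1, 'Saipan is the capital of the CNMI.', 0.92)]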

api/core/model_runtime/model_providers/gpustack/text_embedding/__init__.py

Whitespace-only changes.
@@ -0,0 +1,35 @@
from typing import Optional

from yarl import URL

from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.text_embedding_entities import (
    TextEmbeddingResult,
)
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
    OAICompatEmbeddingModel,
)


class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
    """
    Model class for GPUStack text embedding model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        texts: list[str],
        user: Optional[str] = None,
        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
    ) -> TextEmbeddingResult:
        return super()._invoke(model, credentials, texts, user, input_type)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
