
Commit 3ac7e01

larcane97 authored and moon committed
Add VESSL AI OpenAI API-compatible model provider and LLM model (#9474)
Co-authored-by: moon <[email protected]>
1 parent d17e9e4 commit 3ac7e01

File tree

10 files changed: +289 −1 lines changed


api/core/model_runtime/model_providers/vessl_ai/__init__.py

Whitespace-only changes.

api/core/model_runtime/model_providers/vessl_ai/llm/__init__.py

Whitespace-only changes.
api/core/model_runtime/model_providers/vessl_ai/llm/llm.py

@@ -0,0 +1,83 @@
from decimal import Decimal

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    DefaultParameterName,
    FetchFrom,
    ModelPropertyKey,
    ModelType,
    ParameterRule,
    ParameterType,
    PriceConfig,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        features = []

        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.LLM,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            features=features,
            model_properties={
                ModelPropertyKey.MODE: credentials.get("mode"),
            },
            parameter_rules=[
                ParameterRule(
                    name=DefaultParameterName.TEMPERATURE.value,
                    label=I18nObject(en_US="Temperature"),
                    type=ParameterType.FLOAT,
                    default=float(credentials.get("temperature", 0.7)),
                    min=0,
                    max=2,
                    precision=2,
                ),
                ParameterRule(
                    name=DefaultParameterName.TOP_P.value,
                    label=I18nObject(en_US="Top P"),
                    type=ParameterType.FLOAT,
                    default=float(credentials.get("top_p", 1)),
                    min=0,
                    max=1,
                    precision=2,
                ),
                ParameterRule(
                    name=DefaultParameterName.TOP_K.value,
                    label=I18nObject(en_US="Top K"),
                    type=ParameterType.INT,
                    default=int(credentials.get("top_k", 50)),
                    min=-2147483647,
                    max=2147483647,
                    precision=0,
                ),
                ParameterRule(
                    name=DefaultParameterName.MAX_TOKENS.value,
                    label=I18nObject(en_US="Max Tokens"),
                    type=ParameterType.INT,
                    default=512,
                    min=1,
                    max=int(credentials.get("max_tokens_to_sample", 4096)),
                ),
            ],
            pricing=PriceConfig(
                input=Decimal(credentials.get("input_price", 0)),
                output=Decimal(credentials.get("output_price", 0)),
                unit=Decimal(credentials.get("unit", 0)),
                currency=credentials.get("currency", "USD"),
            ),
        )

        if credentials["mode"] == "chat":
            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
        elif credentials["mode"] == "completion":
            entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
        else:
            raise ValueError(f"Unknown completion mode {credentials['mode']}")

        return entity
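
A minimal usage sketch of the schema method above; the model name and credential values here are hypothetical, not part of the commit:

from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.vessl_ai.llm.llm import VesslAILargeLanguageModel

model = VesslAILargeLanguageModel()
entity = model.get_customizable_model_schema(
    model="llama-3-8b-instruct",  # hypothetical model name
    credentials={
        "mode": "chat",  # "chat" or "completion"; anything else raises ValueError
        "endpoint_url": "https://my-endpoint.example.com/v1",  # hypothetical
        "api_key": "vssl-dummy",  # hypothetical
    },
)
print(entity.model_properties[ModelPropertyKey.MODE])  # -> "chat"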
api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py

@@ -0,0 +1,10 @@
import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class VesslAIProvider(ModelProvider):
    def validate_provider_credentials(self, credentials: dict) -> None:
        pass
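
Provider-level validation is a no-op here: with customizable-model as the only configuration method, credentials are checked per model instead (see the integration tests below). A minimal sketch, assuming the file sits at api/core/model_runtime/model_providers/vessl_ai/vessl_ai.py per the provider naming convention:

from core.model_runtime.model_providers.vessl_ai.vessl_ai import VesslAIProvider

provider = VesslAIProvider()
provider.validate_provider_credentials(credentials={})  # returns None; nothing to check at the provider level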
api/core/model_runtime/model_providers/vessl_ai/vessl_ai.yaml

@@ -0,0 +1,56 @@
provider: vessl_ai
label:
  en_US: vessl_ai
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.png
background: "#F1EFED"
help:
  title:
    en_US: How to deploy VESSL AI LLM Model Endpoint
  url:
    en_US: https://docs.vessl.ai/guides/get-started/llama3-deployment
supported_model_types:
  - llm
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
    placeholder:
      en_US: Enter your model name
  credential_form_schemas:
    - variable: endpoint_url
      label:
        en_US: Endpoint URL
      type: text-input
      required: true
      placeholder:
        en_US: Enter the URL of your endpoint
    - variable: api_key
      required: true
      label:
        en_US: API Key
      type: secret-input
      placeholder:
        en_US: Enter your VESSL AI secret key
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        en_US: Completion mode
      type: select
      required: false
      default: chat
      placeholder:
        en_US: Select completion mode
      options:
        - value: completion
          label:
            en_US: Completion
        - value: chat
          label:
            en_US: Chat
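
The variable names in this schema become the keys of the credentials dict handed to VesslAILargeLanguageModel above; illustratively (values hypothetical):

credentials = {
    "endpoint_url": "https://my-endpoint.example.com/v1",  # text-input, required
    "api_key": "vssl-dummy",  # secret-input, required
    "mode": "chat",  # select; defaults to "chat"
}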

api/tests/integration_tests/.env.example

+6 −1
@@ -84,5 +84,10 @@ VOLC_EMBEDDING_ENDPOINT_ID=
 # 360 AI Credentials
 ZHINAO_API_KEY=
 
+# VESSL AI Credentials
+VESSL_AI_MODEL_NAME=
+VESSL_AI_API_KEY=
+VESSL_AI_ENDPOINT_URL=
+
 # Gitee AI Credentials
-GITEE_AI_API_KEY=
+GITEE_AI_API_KEY=

api/tests/integration_tests/model_runtime/vessl_ai/__init__.py

Whitespace-only changes.
api/tests/integration_tests/model_runtime/vessl_ai/test_llm.py

@@ -0,0 +1,131 @@
import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.vessl_ai.llm.llm import VesslAILargeLanguageModel


def test_validate_credentials():
    model = VesslAILargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model=os.environ.get("VESSL_AI_MODEL_NAME"),
            credentials={
                "api_key": "invalid_key",
                "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
                "mode": "chat",
            },
        )

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model=os.environ.get("VESSL_AI_MODEL_NAME"),
            credentials={
                "api_key": os.environ.get("VESSL_AI_API_KEY"),
                "endpoint_url": "http://invalid_url",
                "mode": "chat",
            },
        )

    model.validate_credentials(
        model=os.environ.get("VESSL_AI_MODEL_NAME"),
        credentials={
            "api_key": os.environ.get("VESSL_AI_API_KEY"),
            "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
            "mode": "chat",
        },
    )


def test_invoke_model():
    model = VesslAILargeLanguageModel()

    response = model.invoke(
        model=os.environ.get("VESSL_AI_MODEL_NAME"),
        credentials={
            "api_key": os.environ.get("VESSL_AI_API_KEY"),
            "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Who are you?"),
        ],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=False,
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0


def test_invoke_stream_model():
    model = VesslAILargeLanguageModel()

    response = model.invoke(
        model=os.environ.get("VESSL_AI_MODEL_NAME"),
        credentials={
            "api_key": os.environ.get("VESSL_AI_API_KEY"),
            "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Who are you?"),
        ],
        model_parameters={
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=["How"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)


def test_get_num_tokens():
    model = VesslAILargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model=os.environ.get("VESSL_AI_MODEL_NAME"),
        credentials={
            "api_key": os.environ.get("VESSL_AI_API_KEY"),
            "endpoint_url": os.environ.get("VESSL_AI_ENDPOINT_URL"),
        },
        prompt_messages=[
            SystemPromptMessage(
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(content="Hello World!"),
        ],
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 21
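
These tests call a live endpoint, so the VESSL_AI_MODEL_NAME, VESSL_AI_API_KEY, and VESSL_AI_ENDPOINT_URL entries added to .env.example above must be populated first; a typical invocation (assumed, not part of the commit) would be pytest tests/integration_tests/model_runtime/vessl_ai/test_llm.py from the api directory. The exact num_tokens == 21 assertion is tied to the tokenizer used by the OpenAI API-compatible base class.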
