1
1
from collections .abc import Generator
2
2
from typing import Optional , Union
3
3
4
- from core .model_runtime .entities .llm_entities import LLMResult
4
+ from core .model_runtime .entities .common_entities import I18nObject
5
+ from core .model_runtime .entities .llm_entities import LLMMode , LLMResult
5
6
from core .model_runtime .entities .message_entities import PromptMessage , PromptMessageTool
7
+ from core .model_runtime .entities .model_entities import (
8
+ AIModelEntity ,
9
+ FetchFrom ,
10
+ ModelFeature ,
11
+ ModelPropertyKey ,
12
+ ModelType ,
13
+ ParameterRule ,
14
+ ParameterType ,
15
+ )
6
16
from core .model_runtime .model_providers .openai_api_compatible .llm .llm import OAIAPICompatLargeLanguageModel
7
17
8
18
@@ -29,3 +39,53 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
29
39
def _add_custom_parameters (cls , credentials : dict ) -> None :
30
40
credentials ["mode" ] = "chat"
31
41
credentials ["endpoint_url" ] = "https://api.siliconflow.cn/v1"
42
+
43
+ def get_customizable_model_schema (self , model : str , credentials : dict ) -> AIModelEntity | None :
44
+ return AIModelEntity (
45
+ model = model ,
46
+ label = I18nObject (en_US = model , zh_Hans = model ),
47
+ model_type = ModelType .LLM ,
48
+ features = [ModelFeature .TOOL_CALL , ModelFeature .MULTI_TOOL_CALL , ModelFeature .STREAM_TOOL_CALL ]
49
+ if credentials .get ("function_calling_type" ) == "tool_call"
50
+ else [],
51
+ fetch_from = FetchFrom .CUSTOMIZABLE_MODEL ,
52
+ model_properties = {
53
+ ModelPropertyKey .CONTEXT_SIZE : int (credentials .get ("context_size" , 8000 )),
54
+ ModelPropertyKey .MODE : LLMMode .CHAT .value ,
55
+ },
56
+ parameter_rules = [
57
+ ParameterRule (
58
+ name = "temperature" ,
59
+ use_template = "temperature" ,
60
+ label = I18nObject (en_US = "Temperature" , zh_Hans = "温度" ),
61
+ type = ParameterType .FLOAT ,
62
+ ),
63
+ ParameterRule (
64
+ name = "max_tokens" ,
65
+ use_template = "max_tokens" ,
66
+ default = 512 ,
67
+ min = 1 ,
68
+ max = int (credentials .get ("max_tokens" , 1024 )),
69
+ label = I18nObject (en_US = "Max Tokens" , zh_Hans = "最大标记" ),
70
+ type = ParameterType .INT ,
71
+ ),
72
+ ParameterRule (
73
+ name = "top_p" ,
74
+ use_template = "top_p" ,
75
+ label = I18nObject (en_US = "Top P" , zh_Hans = "Top P" ),
76
+ type = ParameterType .FLOAT ,
77
+ ),
78
+ ParameterRule (
79
+ name = "top_k" ,
80
+ use_template = "top_k" ,
81
+ label = I18nObject (en_US = "Top K" , zh_Hans = "Top K" ),
82
+ type = ParameterType .FLOAT ,
83
+ ),
84
+ ParameterRule (
85
+ name = "frequency_penalty" ,
86
+ use_template = "frequency_penalty" ,
87
+ label = I18nObject (en_US = "Frequency Penalty" , zh_Hans = "重复惩罚" ),
88
+ type = ParameterType .FLOAT ,
89
+ ),
90
+ ],
91
+ )
0 commit comments