Skip to content

Commit 0a97ed4

Browse files
authored
Merge branch 'main' into couchintegration
2 parents 348e32a + f4ce082 commit 0a97ed4

File tree

23 files changed

+309
-97
lines changed

23 files changed

+309
-97
lines changed

api/controllers/console/workspace/model_providers.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,12 @@ class ModelProviderIconApi(Resource):
126126
Get model provider icon
127127
"""
128128

129-
@setup_required
130-
@login_required
131-
@account_initialization_required
132129
def get(self, provider: str, icon_type: str, lang: str):
133130
model_provider_service = ModelProviderService()
134131
icon, mimetype = model_provider_service.get_model_provider_icon(
135-
provider=provider, icon_type=icon_type, lang=lang
132+
provider=provider,
133+
icon_type=icon_type,
134+
lang=lang,
136135
)
137136

138137
return send_file(io.BytesIO(icon), mimetype=mimetype)

api/core/model_runtime/model_providers/siliconflow/llm/llm.py

+61-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
from collections.abc import Generator
22
from typing import Optional, Union
33

4-
from core.model_runtime.entities.llm_entities import LLMResult
4+
from core.model_runtime.entities.common_entities import I18nObject
5+
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
56
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
7+
from core.model_runtime.entities.model_entities import (
8+
AIModelEntity,
9+
FetchFrom,
10+
ModelFeature,
11+
ModelPropertyKey,
12+
ModelType,
13+
ParameterRule,
14+
ParameterType,
15+
)
616
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
717

818

@@ -29,3 +39,53 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
2939
def _add_custom_parameters(cls, credentials: dict) -> None:
3040
credentials["mode"] = "chat"
3141
credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"
42+
43+
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
44+
return AIModelEntity(
45+
model=model,
46+
label=I18nObject(en_US=model, zh_Hans=model),
47+
model_type=ModelType.LLM,
48+
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
49+
if credentials.get("function_calling_type") == "tool_call"
50+
else [],
51+
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
52+
model_properties={
53+
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)),
54+
ModelPropertyKey.MODE: LLMMode.CHAT.value,
55+
},
56+
parameter_rules=[
57+
ParameterRule(
58+
name="temperature",
59+
use_template="temperature",
60+
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
61+
type=ParameterType.FLOAT,
62+
),
63+
ParameterRule(
64+
name="max_tokens",
65+
use_template="max_tokens",
66+
default=512,
67+
min=1,
68+
max=int(credentials.get("max_tokens", 1024)),
69+
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
70+
type=ParameterType.INT,
71+
),
72+
ParameterRule(
73+
name="top_p",
74+
use_template="top_p",
75+
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
76+
type=ParameterType.FLOAT,
77+
),
78+
ParameterRule(
79+
name="top_k",
80+
use_template="top_k",
81+
label=I18nObject(en_US="Top K", zh_Hans="Top K"),
82+
type=ParameterType.FLOAT,
83+
),
84+
ParameterRule(
85+
name="frequency_penalty",
86+
use_template="frequency_penalty",
87+
label=I18nObject(en_US="Frequency Penalty", zh_Hans="重复惩罚"),
88+
type=ParameterType.FLOAT,
89+
),
90+
],
91+
)

api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml

+55
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ supported_model_types:
2020
- speech2text
2121
configurate_methods:
2222
- predefined-model
23+
- customizable-model
2324
provider_credential_schema:
2425
credential_form_schemas:
2526
- variable: api_key
@@ -30,3 +31,57 @@ provider_credential_schema:
3031
placeholder:
3132
zh_Hans: 在此输入您的 API Key
3233
en_US: Enter your API Key
34+
model_credential_schema:
35+
model:
36+
label:
37+
en_US: Model Name
38+
zh_Hans: 模型名称
39+
placeholder:
40+
en_US: Enter your model name
41+
zh_Hans: 输入模型名称
42+
credential_form_schemas:
43+
- variable: api_key
44+
label:
45+
en_US: API Key
46+
type: secret-input
47+
required: true
48+
placeholder:
49+
zh_Hans: 在此输入您的 API Key
50+
en_US: Enter your API Key
51+
- variable: context_size
52+
label:
53+
zh_Hans: 模型上下文长度
54+
en_US: Model context size
55+
required: true
56+
type: text-input
57+
default: '4096'
58+
placeholder:
59+
zh_Hans: 在此输入您的模型上下文长度
60+
en_US: Enter your Model context size
61+
- variable: max_tokens
62+
label:
63+
zh_Hans: 最大 token 上限
64+
en_US: Upper bound for max tokens
65+
default: '4096'
66+
type: text-input
67+
show_on:
68+
- variable: __model_type
69+
value: llm
70+
- variable: function_calling_type
71+
label:
72+
en_US: Function calling
73+
type: select
74+
required: false
75+
default: no_call
76+
options:
77+
- value: no_call
78+
label:
79+
en_US: Not Support
80+
zh_Hans: 不支持
81+
- value: function_call
82+
label:
83+
en_US: Support
84+
zh_Hans: 支持
85+
show_on:
86+
- variable: __model_type
87+
value: llm

api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import logging
3+
import math
34
from typing import Any, Optional
45
from urllib.parse import urlparse
56

@@ -112,7 +113,8 @@ def delete(self) -> None:
112113

113114
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
114115
top_k = kwargs.get("top_k", 10)
115-
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k}
116+
num_candidates = math.ceil(top_k * 1.5)
117+
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
116118

117119
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
118120

api/services/dataset_service.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from flask_login import current_user
1010
from sqlalchemy import func
11+
from werkzeug.exceptions import NotFound
1112

1213
from configs import dify_config
1314
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
@@ -975,6 +976,8 @@ def update_document_with_dataset_id(
975976
):
976977
DatasetService.check_dataset_model_setting(dataset)
977978
document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
979+
if document is None:
980+
raise NotFound("Document not found")
978981
if document.display_status != "available":
979982
raise ValueError("Document is not available")
980983
# update document name

api/services/enterprise/base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,16 @@ class EnterpriseRequest:
77
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
88
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
99

10+
proxies = {
11+
"http": None,
12+
"https": None,
13+
}
14+
1015
@classmethod
1116
def send_request(cls, method, endpoint, json=None, params=None):
1217
headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key}
1318

1419
url = f"{cls.base_url}{endpoint}"
15-
response = requests.request(method, url, json=json, params=params, headers=headers)
20+
response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies)
1621

1722
return response.json()

api/tests/artifact_tests/dependencies/test_dependencies_sorted.py

+14-26
Original file line numberDiff line numberDiff line change
@@ -2,36 +2,23 @@
22

33
import toml
44

5-
ALL_DEPENDENCY_GROUP_NAMES = [
6-
# default main group
7-
"",
8-
# required groups
9-
"indirect",
10-
"storage",
11-
"tools",
12-
"vdb",
13-
# optional groups
14-
"dev",
15-
"lint",
16-
]
17-
185

196
def load_api_poetry_configs() -> dict[str, Any]:
207
pyproject_toml = toml.load("api/pyproject.toml")
21-
return pyproject_toml.get("tool").get("poetry")
8+
return pyproject_toml["tool"]["poetry"]
229

2310

24-
def load_dependency_groups() -> dict[str, dict[str, dict[str, Any]]]:
25-
poetry_configs = load_api_poetry_configs()
26-
group_name_to_dependencies = {
27-
group_name: (poetry_configs.get("group").get(group_name) if group_name else poetry_configs).get("dependencies")
28-
for group_name in ALL_DEPENDENCY_GROUP_NAMES
29-
}
30-
return group_name_to_dependencies
11+
def load_all_dependency_groups() -> dict[str, dict[str, dict[str, Any]]]:
12+
configs = load_api_poetry_configs()
13+
configs_by_group = {"main": configs}
14+
for group_name in configs["group"]:
15+
configs_by_group[group_name] = configs["group"][group_name]
16+
dependencies_by_group = {group_name: base["dependencies"] for group_name, base in configs_by_group.items()}
17+
return dependencies_by_group
3118

3219

3320
def test_group_dependencies_sorted():
34-
for group_name, dependencies in load_dependency_groups().items():
21+
for group_name, dependencies in load_all_dependency_groups().items():
3522
dependency_names = list(dependencies.keys())
3623
expected_dependency_names = sorted(set(dependency_names))
3724
section = f"tool.poetry.group.{group_name}.dependencies" if group_name else "tool.poetry.dependencies"
@@ -42,17 +29,18 @@ def test_group_dependencies_sorted():
4229

4330

4431
def test_group_dependencies_version_operator():
45-
for group_name, dependencies in load_dependency_groups().items():
32+
for group_name, dependencies in load_all_dependency_groups().items():
4633
for dependency_name, specification in dependencies.items():
47-
version_spec = specification if isinstance(specification, str) else specification.get("version")
34+
version_spec = specification if isinstance(specification, str) else specification["version"]
4835
assert not version_spec.startswith("^"), (
49-
f"'^' is not allowed in dependency version," f" but found in '{dependency_name} = {version_spec}'"
36+
f"Please replace '{dependency_name} = {version_spec}' with '{dependency_name} = ~{version_spec[1:]}' "
37+
f"'^' operator is too wide and not allowed in the version specification."
5038
)
5139

5240

5341
def test_duplicated_dependency_crossing_groups():
5442
all_dependency_names: list[str] = []
55-
for dependencies in load_dependency_groups().values():
43+
for dependencies in load_all_dependency_groups().values():
5644
dependency_names = list(dependencies.keys())
5745
all_dependency_names.extend(dependency_names)
5846
expected_all_dependency_names = set(all_dependency_names)

docker/.env.example

+2
Original file line numberDiff line numberDiff line change
@@ -804,3 +804,5 @@ POSITION_TOOL_EXCLUDES=
804804
POSITION_PROVIDER_PINS=
805805
POSITION_PROVIDER_INCLUDES=
806806
POSITION_PROVIDER_EXCLUDES=
807+
# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
808+
CSP_WHITELIST=

0 commit comments

Comments
 (0)