Skip to content

feat: Add caching mechanism for plugin model schemas #14898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions api/contexts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from contexts.wrapper import RecyclableContextVar

if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.workflow.entities.variable_pool import VariablePool
Expand All @@ -20,11 +21,19 @@
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
ContextVar("plugin_tool_providers")
)

plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))

plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
ContextVar("plugin_model_providers")
)

plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock")
)

plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))

plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)
41 changes: 32 additions & 9 deletions api/core/model_runtime/model_providers/__base/ai_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import decimal
import hashlib
from threading import Lock
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field

import contexts
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import (
Expand Down Expand Up @@ -139,15 +142,35 @@ def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Op
:return: model schema
"""
plugin_model_manager = PluginModelManager()
return plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials or {},
)
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
# sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])

try:
contexts.plugin_model_schemas.get()
except LookupError:
contexts.plugin_model_schemas.set({})
contexts.plugin_model_schema_lock.set(Lock())

with contexts.plugin_model_schema_lock.get():
if cache_key in contexts.plugin_model_schemas.get():
return contexts.plugin_model_schemas.get()[cache_key]

schema = plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials or {},
)

if schema:
contexts.plugin_model_schemas.get()[cache_key] = schema

return schema

def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Expand Down
39 changes: 29 additions & 10 deletions api/core/model_runtime/model_providers/model_provider_factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import logging
import os
from collections.abc import Sequence
Expand Down Expand Up @@ -206,17 +207,35 @@ def get_model_schema(
Get model schema
"""
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
model_schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)
cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
# sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else []
cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])

return model_schema
try:
contexts.plugin_model_schemas.get()
except LookupError:
contexts.plugin_model_schemas.set({})
contexts.plugin_model_schema_lock.set(Lock())

with contexts.plugin_model_schema_lock.get():
if cache_key in contexts.plugin_model_schemas.get():
return contexts.plugin_model_schemas.get()[cache_key]

schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials or {},
)

if schema:
contexts.plugin_model_schemas.get()[cache_key] = schema

return schema

def get_models(
self,
Expand Down
Loading