Skip to content

refactor: optimize database usage #12071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
352 changes: 180 additions & 172 deletions api/core/app/apps/advanced_chat/generate_task_pipeline.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion api/core/app/apps/message_based_app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def _handle_response(
queue_manager=queue_manager,
conversation=conversation,
message=message,
user=user,
stream=stream,
)

Expand Down
192 changes: 106 additions & 86 deletions api/core/app/apps/workflow/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from collections.abc import Generator
from typing import Any, Optional, Union

from sqlalchemy.orm import Session

from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager
Expand Down Expand Up @@ -50,6 +52,7 @@
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
from models.model import EndUser
from models.workflow import (
Workflow,
Expand All @@ -68,8 +71,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
WorkflowAppGenerateTaskPipeline is a class that generates stream output and manages state for the Application.
"""

_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
Expand All @@ -83,25 +84,27 @@ def __init__(
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: user
:param stream: is streamed
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
super().__init__(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
stream=stream,
)

if isinstance(self._user, EndUser):
user_id = self._user.session_id
if isinstance(user, EndUser):
self._user_id = user.session_id
self._created_by_role = CreatedByRole.END_USER
elif isinstance(user, Account):
self._user_id = user.id
self._created_by_role = CreatedByRole.ACCOUNT
else:
user_id = self._user.id
raise ValueError(f"Invalid user type: {type(user)}")

self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict

self._workflow = workflow
self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.USER_ID: self._user_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
Expand All @@ -115,10 +118,6 @@ def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStr
Process the generate task pipeline.
:return:
"""
db.session.refresh(self._workflow)
db.session.refresh(self._user)
db.session.close()

generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
return self._to_stream_response(generator)
Expand Down Expand Up @@ -185,7 +184,7 @@ def _wrapper_process_stream_response(
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
features_dict = self._workflow_features_dict

if (
features_dict.get("text_to_speech")
Expand Down Expand Up @@ -242,18 +241,26 @@ def _process_stream_response(
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event)
err = self._handle_error(event=event)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state

# init workflow run
workflow_run = self._handle_workflow_run_start()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
with Session(db.engine) as session:
# init workflow run
workflow_run = self._handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
start_resp = self._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield start_resp
elif isinstance(
event,
QueueNodeRetryEvent,
Expand Down Expand Up @@ -350,72 +357,87 @@ def _process_stream_response(
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
)

# save workflow app log
self._save_workflow_app_log(workflow_run)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_success(
session=session,
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
)

# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)

workflow_finish_resp = self._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise ValueError("workflow run not initialized.")

if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")

workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)

# save workflow app log
self._save_workflow_app_log(workflow_run)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_partial_success(
session=session,
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)

# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)

workflow_finish_resp = self._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()

yield workflow_finish_resp
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise ValueError("workflow run not initialized.")

if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)

# save workflow app log
self._save_workflow_app_log(workflow_run)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
session=session,
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)

# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)

workflow_finish_resp = self._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
Expand All @@ -435,7 +457,7 @@ def _process_stream_response(
if tts_publisher:
tts_publisher.publish(None)

def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None:
"""
Save workflow app log.
:return:
Expand All @@ -457,12 +479,10 @@ def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user"
workflow_app_log.created_by = self._user.id
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id

db.session.add(workflow_app_log)
db.session.commit()
db.session.close()
session.add(workflow_app_log)

def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None
Expand Down
36 changes: 15 additions & 21 deletions api/core/app/task_pipeline/based_generate_task_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import logging
import time
from typing import Optional, Union
from typing import Optional

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import (
Expand All @@ -17,9 +20,7 @@
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser, Message
from models.model import Message

logger = logging.getLogger(__name__)

Expand All @@ -36,7 +37,6 @@ def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Expand All @@ -48,18 +48,11 @@ def __init__(
"""
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
self._user = user
self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation()
self._stream = stream

def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None):
"""
Handle error event.
:param event: event
:param message: message
:return:
"""
def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
logger.debug("error: %s", event.error)
e = event.error
err: Exception
Expand All @@ -71,16 +64,17 @@ def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = Non
else:
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))

if message:
refetch_message = db.session.query(Message).filter(Message.id == message.id).first()

if refetch_message:
err_desc = self._error_to_desc(err)
refetch_message.status = "error"
refetch_message.error = err_desc
if not message_id or not session:
return err

db.session.commit()
stmt = select(Message).where(Message.id == message_id)
message = session.scalar(stmt)
if not message:
return err

err_desc = self._error_to_desc(err)
message.status = "error"
message.error = err_desc
return err

def _error_to_desc(self, e: Exception) -> str:
Expand Down
Loading
Loading