Skip to content

Commit 39e14ab

Browse files
committed
refactor: streamline initialization of application_generate_entity and task_state in task pipeline classes
Signed-off-by: -LAN- <[email protected]>
1 parent 1886bb1 commit 39e14ab

File tree

5 files changed

+117
-108
lines changed

5 files changed

+117
-108
lines changed

api/core/app/apps/advanced_chat/generate_task_pipeline.py

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
from models.enums import CreatedByRole
6868
from models.workflow import (
6969
Workflow,
70-
WorkflowNodeExecution,
7170
WorkflowRunStatus,
7271
)
7372

@@ -79,12 +78,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
7978
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
8079
"""
8180

82-
_task_state: WorkflowTaskState
83-
_application_generate_entity: AdvancedChatAppGenerateEntity
84-
_workflow_system_variables: dict[SystemVariableKey, Any]
85-
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
86-
_conversation_name_generate_thread: Optional[Thread] = None
87-
8881
def __init__(
8982
self,
9083
application_generate_entity: AdvancedChatAppGenerateEntity,
@@ -96,10 +89,8 @@ def __init__(
9689
stream: bool,
9790
dialogue_count: int,
9891
) -> None:
99-
super().__init__(
100-
application_generate_entity=application_generate_entity,
101-
queue_manager=queue_manager,
102-
stream=stream,
92+
BasedGenerateTaskPipeline.__init__(
93+
self, application_generate_entity=application_generate_entity, queue_manager=queue_manager, stream=stream
10394
)
10495

10596
if isinstance(user, EndUser):
@@ -112,33 +103,36 @@ def __init__(
112103
self._created_by_role = CreatedByRole.ACCOUNT
113104
else:
114105
raise NotImplementedError(f"User type not supported: {type(user)}")
106+
WorkflowCycleManage.__init__(
107+
self,
108+
application_generate_entity=application_generate_entity,
109+
workflow_system_variables={
110+
SystemVariableKey.QUERY: message.query,
111+
SystemVariableKey.FILES: application_generate_entity.files,
112+
SystemVariableKey.CONVERSATION_ID: conversation.id,
113+
SystemVariableKey.USER_ID: user_session_id,
114+
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
115+
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
116+
SystemVariableKey.WORKFLOW_ID: workflow.id,
117+
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
118+
},
119+
)
120+
121+
self._task_state = WorkflowTaskState()
122+
MessageCycleManage.__init__(
123+
self, application_generate_entity=application_generate_entity, task_state=self._task_state
124+
)
115125

126+
self._application_generate_entity = application_generate_entity
116127
self._workflow_id = workflow.id
117128
self._workflow_features_dict = workflow.features_dict
118-
119129
self._conversation_id = conversation.id
120130
self._conversation_mode = conversation.mode
121-
122131
self._message_id = message.id
123132
self._message_created_at = int(message.created_at.timestamp())
124-
125-
self._workflow_system_variables = {
126-
SystemVariableKey.QUERY: message.query,
127-
SystemVariableKey.FILES: application_generate_entity.files,
128-
SystemVariableKey.CONVERSATION_ID: conversation.id,
129-
SystemVariableKey.USER_ID: user_session_id,
130-
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
131-
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
132-
SystemVariableKey.WORKFLOW_ID: workflow.id,
133-
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
134-
}
135-
136-
self._task_state = WorkflowTaskState()
137-
self._wip_workflow_node_executions = {}
138-
139-
self._conversation_name_generate_thread = None
133+
self._conversation_name_generate_thread: Thread | None = None
140134
self._recorded_files: list[Mapping[str, Any]] = []
141-
self._workflow_run_id = ""
135+
self._workflow_run_id: str = ""
142136

143137
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
144138
"""
@@ -275,7 +269,7 @@ def _process_stream_response(
275269
if isinstance(event, QueuePingEvent):
276270
yield self._ping_stream_response()
277271
elif isinstance(event, QueueErrorEvent):
278-
with Session(db.engine) as session:
272+
with Session(db.engine, expire_on_commit=False) as session:
279273
err = self._handle_error(event=event, session=session, message_id=self._message_id)
280274
session.commit()
281275
yield self._error_to_stream_response(err)
@@ -284,7 +278,7 @@ def _process_stream_response(
284278
# override graph runtime state
285279
graph_runtime_state = event.graph_runtime_state
286280

287-
with Session(db.engine) as session:
281+
with Session(db.engine, expire_on_commit=False) as session:
288282
# init workflow run
289283
workflow_run = self._handle_workflow_run_start(
290284
session=session,
@@ -310,7 +304,7 @@ def _process_stream_response(
310304
if not self._workflow_run_id:
311305
raise ValueError("workflow run not initialized.")
312306

313-
with Session(db.engine) as session:
307+
with Session(db.engine, expire_on_commit=False) as session:
314308
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
315309
workflow_node_execution = self._handle_workflow_node_execution_retried(
316310
session=session, workflow_run=workflow_run, event=event
@@ -329,7 +323,7 @@ def _process_stream_response(
329323
if not self._workflow_run_id:
330324
raise ValueError("workflow run not initialized.")
331325

332-
with Session(db.engine) as session:
326+
with Session(db.engine, expire_on_commit=False) as session:
333327
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
334328
workflow_node_execution = self._handle_node_execution_start(
335329
session=session, workflow_run=workflow_run, event=event
@@ -350,7 +344,7 @@ def _process_stream_response(
350344
if event.node_type in [NodeType.ANSWER, NodeType.END]:
351345
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
352346

353-
with Session(db.engine) as session:
347+
with Session(db.engine, expire_on_commit=False) as session:
354348
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
355349

356350
node_finish_resp = self._workflow_node_finish_to_stream_response(
@@ -364,7 +358,7 @@ def _process_stream_response(
364358
if node_finish_resp:
365359
yield node_finish_resp
366360
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
367-
with Session(db.engine) as session:
361+
with Session(db.engine, expire_on_commit=False) as session:
368362
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
369363

370364
node_finish_resp = self._workflow_node_finish_to_stream_response(
@@ -381,7 +375,7 @@ def _process_stream_response(
381375
if not self._workflow_run_id:
382376
raise ValueError("workflow run not initialized.")
383377

384-
with Session(db.engine) as session:
378+
with Session(db.engine, expire_on_commit=False) as session:
385379
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
386380
parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
387381
session=session,
@@ -395,7 +389,7 @@ def _process_stream_response(
395389
if not self._workflow_run_id:
396390
raise ValueError("workflow run not initialized.")
397391

398-
with Session(db.engine) as session:
392+
with Session(db.engine, expire_on_commit=False) as session:
399393
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
400394
parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
401395
session=session,
@@ -409,7 +403,7 @@ def _process_stream_response(
409403
if not self._workflow_run_id:
410404
raise ValueError("workflow run not initialized.")
411405

412-
with Session(db.engine) as session:
406+
with Session(db.engine, expire_on_commit=False) as session:
413407
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
414408
iter_start_resp = self._workflow_iteration_start_to_stream_response(
415409
session=session,
@@ -423,7 +417,7 @@ def _process_stream_response(
423417
if not self._workflow_run_id:
424418
raise ValueError("workflow run not initialized.")
425419

426-
with Session(db.engine) as session:
420+
with Session(db.engine, expire_on_commit=False) as session:
427421
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
428422
iter_next_resp = self._workflow_iteration_next_to_stream_response(
429423
session=session,
@@ -437,7 +431,7 @@ def _process_stream_response(
437431
if not self._workflow_run_id:
438432
raise ValueError("workflow run not initialized.")
439433

440-
with Session(db.engine) as session:
434+
with Session(db.engine, expire_on_commit=False) as session:
441435
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
442436
iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
443437
session=session,
@@ -454,7 +448,7 @@ def _process_stream_response(
454448
if not graph_runtime_state:
455449
raise ValueError("workflow run not initialized.")
456450

457-
with Session(db.engine) as session:
451+
with Session(db.engine, expire_on_commit=False) as session:
458452
workflow_run = self._handle_workflow_run_success(
459453
session=session,
460454
workflow_run_id=self._workflow_run_id,
@@ -479,7 +473,7 @@ def _process_stream_response(
479473
if not graph_runtime_state:
480474
raise ValueError("graph runtime state not initialized.")
481475

482-
with Session(db.engine) as session:
476+
with Session(db.engine, expire_on_commit=False) as session:
483477
workflow_run = self._handle_workflow_run_partial_success(
484478
session=session,
485479
workflow_run_id=self._workflow_run_id,
@@ -504,7 +498,7 @@ def _process_stream_response(
504498
if not graph_runtime_state:
505499
raise ValueError("graph runtime state not initialized.")
506500

507-
with Session(db.engine) as session:
501+
with Session(db.engine, expire_on_commit=False) as session:
508502
workflow_run = self._handle_workflow_run_failed(
509503
session=session,
510504
workflow_run_id=self._workflow_run_id,
@@ -529,7 +523,7 @@ def _process_stream_response(
529523
break
530524
elif isinstance(event, QueueStopEvent):
531525
if self._workflow_run_id and graph_runtime_state:
532-
with Session(db.engine) as session:
526+
with Session(db.engine, expire_on_commit=False) as session:
533527
workflow_run = self._handle_workflow_run_failed(
534528
session=session,
535529
workflow_run_id=self._workflow_run_id,
@@ -557,7 +551,7 @@ def _process_stream_response(
557551
elif isinstance(event, QueueRetrieverResourcesEvent):
558552
self._handle_retriever_resources(event)
559553

560-
with Session(db.engine) as session:
554+
with Session(db.engine, expire_on_commit=False) as session:
561555
message = self._get_message(session=session)
562556
message.message_metadata = (
563557
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
@@ -566,7 +560,7 @@ def _process_stream_response(
566560
elif isinstance(event, QueueAnnotationReplyEvent):
567561
self._handle_annotation_reply(event)
568562

569-
with Session(db.engine) as session:
563+
with Session(db.engine, expire_on_commit=False) as session:
570564
message = self._get_message(session=session)
571565
message.message_metadata = (
572566
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
@@ -603,7 +597,7 @@ def _process_stream_response(
603597
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
604598

605599
# Save message
606-
with Session(db.engine) as session:
600+
with Session(db.engine, expire_on_commit=False) as session:
607601
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
608602
session.commit()
609603

0 commit comments

Comments
 (0)