67
67
from models .enums import CreatedByRole
68
68
from models .workflow import (
69
69
Workflow ,
70
- WorkflowNodeExecution ,
71
70
WorkflowRunStatus ,
72
71
)
73
72
@@ -79,12 +78,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
79
78
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
80
79
"""
81
80
82
- _task_state : WorkflowTaskState
83
- _application_generate_entity : AdvancedChatAppGenerateEntity
84
- _workflow_system_variables : dict [SystemVariableKey , Any ]
85
- _wip_workflow_node_executions : dict [str , WorkflowNodeExecution ]
86
- _conversation_name_generate_thread : Optional [Thread ] = None
87
-
88
81
def __init__ (
89
82
self ,
90
83
application_generate_entity : AdvancedChatAppGenerateEntity ,
@@ -96,10 +89,8 @@ def __init__(
96
89
stream : bool ,
97
90
dialogue_count : int ,
98
91
) -> None :
99
- super ().__init__ (
100
- application_generate_entity = application_generate_entity ,
101
- queue_manager = queue_manager ,
102
- stream = stream ,
92
+ BasedGenerateTaskPipeline .__init__ (
93
+ self , application_generate_entity = application_generate_entity , queue_manager = queue_manager , stream = stream
103
94
)
104
95
105
96
if isinstance (user , EndUser ):
@@ -112,33 +103,36 @@ def __init__(
112
103
self ._created_by_role = CreatedByRole .ACCOUNT
113
104
else :
114
105
raise NotImplementedError (f"User type not supported: { type (user )} " )
106
+ WorkflowCycleManage .__init__ (
107
+ self ,
108
+ application_generate_entity = application_generate_entity ,
109
+ workflow_system_variables = {
110
+ SystemVariableKey .QUERY : message .query ,
111
+ SystemVariableKey .FILES : application_generate_entity .files ,
112
+ SystemVariableKey .CONVERSATION_ID : conversation .id ,
113
+ SystemVariableKey .USER_ID : user_session_id ,
114
+ SystemVariableKey .DIALOGUE_COUNT : dialogue_count ,
115
+ SystemVariableKey .APP_ID : application_generate_entity .app_config .app_id ,
116
+ SystemVariableKey .WORKFLOW_ID : workflow .id ,
117
+ SystemVariableKey .WORKFLOW_RUN_ID : application_generate_entity .workflow_run_id ,
118
+ },
119
+ )
120
+
121
+ self ._task_state = WorkflowTaskState ()
122
+ MessageCycleManage .__init__ (
123
+ self , application_generate_entity = application_generate_entity , task_state = self ._task_state
124
+ )
115
125
126
+ self ._application_generate_entity = application_generate_entity
116
127
self ._workflow_id = workflow .id
117
128
self ._workflow_features_dict = workflow .features_dict
118
-
119
129
self ._conversation_id = conversation .id
120
130
self ._conversation_mode = conversation .mode
121
-
122
131
self ._message_id = message .id
123
132
self ._message_created_at = int (message .created_at .timestamp ())
124
-
125
- self ._workflow_system_variables = {
126
- SystemVariableKey .QUERY : message .query ,
127
- SystemVariableKey .FILES : application_generate_entity .files ,
128
- SystemVariableKey .CONVERSATION_ID : conversation .id ,
129
- SystemVariableKey .USER_ID : user_session_id ,
130
- SystemVariableKey .DIALOGUE_COUNT : dialogue_count ,
131
- SystemVariableKey .APP_ID : application_generate_entity .app_config .app_id ,
132
- SystemVariableKey .WORKFLOW_ID : workflow .id ,
133
- SystemVariableKey .WORKFLOW_RUN_ID : application_generate_entity .workflow_run_id ,
134
- }
135
-
136
- self ._task_state = WorkflowTaskState ()
137
- self ._wip_workflow_node_executions = {}
138
-
139
- self ._conversation_name_generate_thread = None
133
+ self ._conversation_name_generate_thread : Thread | None = None
140
134
self ._recorded_files : list [Mapping [str , Any ]] = []
141
- self ._workflow_run_id = ""
135
+ self ._workflow_run_id : str = ""
142
136
143
137
def process (self ) -> Union [ChatbotAppBlockingResponse , Generator [ChatbotAppStreamResponse , None , None ]]:
144
138
"""
@@ -275,7 +269,7 @@ def _process_stream_response(
275
269
if isinstance (event , QueuePingEvent ):
276
270
yield self ._ping_stream_response ()
277
271
elif isinstance (event , QueueErrorEvent ):
278
- with Session (db .engine ) as session :
272
+ with Session (db .engine , expire_on_commit = False ) as session :
279
273
err = self ._handle_error (event = event , session = session , message_id = self ._message_id )
280
274
session .commit ()
281
275
yield self ._error_to_stream_response (err )
@@ -284,7 +278,7 @@ def _process_stream_response(
284
278
# override graph runtime state
285
279
graph_runtime_state = event .graph_runtime_state
286
280
287
- with Session (db .engine ) as session :
281
+ with Session (db .engine , expire_on_commit = False ) as session :
288
282
# init workflow run
289
283
workflow_run = self ._handle_workflow_run_start (
290
284
session = session ,
@@ -310,7 +304,7 @@ def _process_stream_response(
310
304
if not self ._workflow_run_id :
311
305
raise ValueError ("workflow run not initialized." )
312
306
313
- with Session (db .engine ) as session :
307
+ with Session (db .engine , expire_on_commit = False ) as session :
314
308
workflow_run = self ._get_workflow_run (session = session , workflow_run_id = self ._workflow_run_id )
315
309
workflow_node_execution = self ._handle_workflow_node_execution_retried (
316
310
session = session , workflow_run = workflow_run , event = event
@@ -329,7 +323,7 @@ def _process_stream_response(
329
323
if not self ._workflow_run_id :
330
324
raise ValueError ("workflow run not initialized." )
331
325
332
- with Session (db .engine ) as session :
326
+ with Session (db .engine , expire_on_commit = False ) as session :
333
327
workflow_run = self ._get_workflow_run (session = session , workflow_run_id = self ._workflow_run_id )
334
328
workflow_node_execution = self ._handle_node_execution_start (
335
329
session = session , workflow_run = workflow_run , event = event
@@ -350,7 +344,7 @@ def _process_stream_response(
350
344
if event .node_type in [NodeType .ANSWER , NodeType .END ]:
351
345
self ._recorded_files .extend (self ._fetch_files_from_node_outputs (event .outputs or {}))
352
346
353
- with Session (db .engine ) as session :
347
+ with Session (db .engine , expire_on_commit = False ) as session :
354
348
workflow_node_execution = self ._handle_workflow_node_execution_success (session = session , event = event )
355
349
356
350
node_finish_resp = self ._workflow_node_finish_to_stream_response (
@@ -364,7 +358,7 @@ def _process_stream_response(
364
358
if node_finish_resp :
365
359
yield node_finish_resp
366
360
elif isinstance (event , QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent ):
367
- with Session (db .engine ) as session :
361
+ with Session (db .engine , expire_on_commit = False ) as session :
368
362
workflow_node_execution = self ._handle_workflow_node_execution_failed (session = session , event = event )
369
363
370
364
node_finish_resp = self ._workflow_node_finish_to_stream_response (
@@ -381,7 +375,7 @@ def _process_stream_response(
381
375
if not self ._workflow_run_id :
382
376
raise ValueError ("workflow run not initialized." )
383
377
384
- with Session (db .engine ) as session :
378
+ with Session (db .engine , expire_on_commit = False ) as session :
385
379
workflow_run = self ._get_workflow_run (session = session , workflow_run_id = self ._workflow_run_id )
386
380
parallel_start_resp = self ._workflow_parallel_branch_start_to_stream_response (
387
381
session = session ,
@@ -395,7 +389,7 @@ def _process_stream_response(
395
389
if not self ._workflow_run_id :
396
390
raise ValueError ("workflow run not initialized." )
397
391
398
- with Session (db .engine ) as session :
392
+ with Session (db .engine , expire_on_commit = False ) as session :
399
393
workflow_run = self ._get_workflow_run (session = session , workflow_run_id = self ._workflow_run_id )
400
394
parallel_finish_resp = self ._workflow_parallel_branch_finished_to_stream_response (
401
395
session = session ,
@@ -409,7 +403,7 @@ def _process_stream_response(
409
403
if not self ._workflow_run_id :
410
404
raise ValueError ("workflow run not initialized." )
411
405
412
- with Session (db .engine ) as session :
406
+ with Session (db .engine , expire_on_commit = False ) as session :
413
407
workflow_run = self ._get_workflow_run (session = session , workflow_run_id = self ._workflow_run_id )
414
408
iter_start_resp = self ._workflow_iteration_start_to_stream_response (
415
409
session = session ,
@@ -423,7 +417,7 @@ def _process_stream_response(
423
417
if not self ._workflow_run_id :
424
418
raise ValueError ("workflow run not initialized." )
425
419
426
- with Session (db .engine ) as session :
420
+ with Session (db .engine , expire_on_commit = False ) as session :
427
421
workflow_run = self ._get_workflow_run (session = session , workflow_run_id = self ._workflow_run_id )
428
422
iter_next_resp = self ._workflow_iteration_next_to_stream_response (
429
423
session = session ,
@@ -437,7 +431,7 @@ def _process_stream_response(
437
431
if not self ._workflow_run_id :
438
432
raise ValueError ("workflow run not initialized." )
439
433
440
- with Session (db .engine ) as session :
434
+ with Session (db .engine , expire_on_commit = False ) as session :
441
435
workflow_run = self ._get_workflow_run (session = session , workflow_run_id = self ._workflow_run_id )
442
436
iter_finish_resp = self ._workflow_iteration_completed_to_stream_response (
443
437
session = session ,
@@ -454,7 +448,7 @@ def _process_stream_response(
454
448
if not graph_runtime_state :
455
449
raise ValueError ("workflow run not initialized." )
456
450
457
- with Session (db .engine ) as session :
451
+ with Session (db .engine , expire_on_commit = False ) as session :
458
452
workflow_run = self ._handle_workflow_run_success (
459
453
session = session ,
460
454
workflow_run_id = self ._workflow_run_id ,
@@ -479,7 +473,7 @@ def _process_stream_response(
479
473
if not graph_runtime_state :
480
474
raise ValueError ("graph runtime state not initialized." )
481
475
482
- with Session (db .engine ) as session :
476
+ with Session (db .engine , expire_on_commit = False ) as session :
483
477
workflow_run = self ._handle_workflow_run_partial_success (
484
478
session = session ,
485
479
workflow_run_id = self ._workflow_run_id ,
@@ -504,7 +498,7 @@ def _process_stream_response(
504
498
if not graph_runtime_state :
505
499
raise ValueError ("graph runtime state not initialized." )
506
500
507
- with Session (db .engine ) as session :
501
+ with Session (db .engine , expire_on_commit = False ) as session :
508
502
workflow_run = self ._handle_workflow_run_failed (
509
503
session = session ,
510
504
workflow_run_id = self ._workflow_run_id ,
@@ -529,7 +523,7 @@ def _process_stream_response(
529
523
break
530
524
elif isinstance (event , QueueStopEvent ):
531
525
if self ._workflow_run_id and graph_runtime_state :
532
- with Session (db .engine ) as session :
526
+ with Session (db .engine , expire_on_commit = False ) as session :
533
527
workflow_run = self ._handle_workflow_run_failed (
534
528
session = session ,
535
529
workflow_run_id = self ._workflow_run_id ,
@@ -557,7 +551,7 @@ def _process_stream_response(
557
551
elif isinstance (event , QueueRetrieverResourcesEvent ):
558
552
self ._handle_retriever_resources (event )
559
553
560
- with Session (db .engine ) as session :
554
+ with Session (db .engine , expire_on_commit = False ) as session :
561
555
message = self ._get_message (session = session )
562
556
message .message_metadata = (
563
557
json .dumps (jsonable_encoder (self ._task_state .metadata )) if self ._task_state .metadata else None
@@ -566,7 +560,7 @@ def _process_stream_response(
566
560
elif isinstance (event , QueueAnnotationReplyEvent ):
567
561
self ._handle_annotation_reply (event )
568
562
569
- with Session (db .engine ) as session :
563
+ with Session (db .engine , expire_on_commit = False ) as session :
570
564
message = self ._get_message (session = session )
571
565
message .message_metadata = (
572
566
json .dumps (jsonable_encoder (self ._task_state .metadata )) if self ._task_state .metadata else None
@@ -603,7 +597,7 @@ def _process_stream_response(
603
597
yield self ._message_replace_to_stream_response (answer = output_moderation_answer )
604
598
605
599
# Save message
606
- with Session (db .engine ) as session :
600
+ with Session (db .engine , expire_on_commit = False ) as session :
607
601
self ._save_message (session = session , graph_runtime_state = graph_runtime_state )
608
602
session .commit ()
609
603
0 commit comments