@@ -85,7 +85,7 @@ class _RequestStreamContext:
85
85
86
86
message_id : str = ''
87
87
last_tool_call_id : str | None = None
88
- part_ends : list [str | None ] = field (default_factory = lambda : list [str | None ]())
88
+ part_ends : list [BaseEvent | None ] = field (default_factory = lambda : list [BaseEvent | None ]())
89
89
local_tool_calls : set [str ] = field (default_factory = set )
90
90
91
91
def new_message_id (self ) -> str :
@@ -217,6 +217,7 @@ async def run(
217
217
Yields:
218
218
Streaming SSE-formatted event chunks.
219
219
"""
220
+ self .logger .warning ('level=%s name=%s' , logging .getLevelName (self .logger .level ), self .logger .name )
220
221
self .logger .debug ('starting run: %s' , json .dumps (run_input .model_dump (), indent = 2 ))
221
222
222
223
tool_names : dict [str , str ] = {self .tool_prefix + tool .name : tool .name for tool in run_input .tools }
@@ -257,13 +258,13 @@ async def run(
257
258
infer_name = infer_name ,
258
259
additional_tools = run_tools ,
259
260
) as run :
260
- async for event in self ._agent_stream (encoder , tool_names , run ):
261
+ async for event in self ._agent_stream (tool_names , run ):
261
262
if event is None :
262
263
# Tool call signals early return, so we stop processing.
263
264
self .logger .debug ('tool call early return' )
264
265
break
265
266
266
- yield event
267
+ yield encoder . encode ( event )
267
268
except RunError as e :
268
269
self .logger .exception ('agent run' )
269
270
yield encoder .encode (
@@ -285,9 +286,7 @@ async def run(
285
286
286
287
self .logger .info ('done thread_id=%s run_id=%s' , run_input .thread_id , run_input .run_id )
287
288
288
- async def _tool_events (
289
- self , encoder : EventEncoder , parts : list [ModelRequestPart ]
290
- ) -> AsyncGenerator [str | None , None ]:
289
+ async def _tool_events (self , parts : list [ModelRequestPart ]) -> AsyncGenerator [BaseEvent | None , None ]:
291
290
"""Check for tool call results that are AG-UI events.
292
291
293
292
Args:
@@ -309,15 +308,15 @@ async def _tool_events(
309
308
match part .content :
310
309
case BaseEvent ():
311
310
self .logger .debug ('ag-ui event: %s' , part .content )
312
- yield encoder . encode ( part .content )
311
+ yield part .content
313
312
case str () | bytes ():
314
313
# Avoid strings and bytes being checked as iterable.
315
314
pass
316
315
case Iterable () as iter :
317
316
for item in iter :
318
317
if isinstance (item , BaseEvent ): # pragma: no branch
319
318
self .logger .debug ('ag-ui event: %s' , item )
320
- yield encoder . encode ( item )
319
+ yield item
321
320
case _: # pragma: no cover
322
321
# Not currently interested in other types.
323
322
pass
@@ -371,51 +370,48 @@ def _tool_stub(*args: Any, **kwargs: Any) -> ToolResult:
371
370
372
371
async def _agent_stream (
373
372
self ,
374
- encoder : EventEncoder ,
375
373
tool_names : dict [str , str ],
376
374
run : AgentRun [AgentDepsT , Any ],
377
- ) -> AsyncGenerator [str | None , None ]:
375
+ ) -> AsyncGenerator [BaseEvent | None , None ]:
378
376
"""Run the agent streaming responses using AG-UI protocol events.
379
377
380
378
Args:
381
- encoder: The event encoder to use for encoding events.
382
379
tool_names: A mapping of tool names to their AG-UI names.
383
380
run: The agent run to process.
384
381
385
382
Yields:
386
383
AG-UI Server-Sent Events (SSE).
387
384
"""
388
385
node : AgentNode [AgentDepsT , Any ] | End [FinalResult [Any ]]
389
- msg : str | None
386
+ msg : BaseEvent | None
390
387
async for node in run :
391
388
self .logger .debug ('processing node=%r' , node )
392
389
if not isinstance (node , ModelRequestNode ):
393
390
# Not interested UserPromptNode, CallToolsNode or End.
394
391
continue
395
392
396
393
# Check for state updates.
397
- snapshot : str | None
398
- async for snapshot in self ._tool_events (encoder , node .request .parts ):
394
+ snapshot : BaseEvent | None
395
+ async for snapshot in self ._tool_events (node .request .parts ):
399
396
yield snapshot
400
397
401
398
stream_ctx : _RequestStreamContext = _RequestStreamContext ()
402
399
request_stream : AgentStream [AgentDepsT ]
403
400
async with node .stream (run .ctx ) as request_stream :
404
401
agent_event : AgentStreamEvent
405
402
async for agent_event in request_stream :
406
- async for msg in self ._handle_agent_event (encoder , tool_names , stream_ctx , agent_event ):
403
+ async for msg in self ._handle_agent_event (tool_names , stream_ctx , agent_event ):
407
404
yield msg
408
405
409
406
for part_end in stream_ctx .part_ends :
410
407
yield part_end
411
408
412
409
async def _handle_agent_event (
413
410
self ,
414
- encoder : EventEncoder ,
415
411
tool_names : dict [str , str ],
416
412
stream_ctx : _RequestStreamContext ,
417
413
agent_event : AgentStreamEvent ,
418
- ) -> AsyncGenerator [str | None , None ]:
414
+ ) -> AsyncGenerator [BaseEvent | None , None ]:
419
415
"""Handle an agent event and yield AG-UI protocol events.
420
416
421
417
Args:
@@ -431,36 +427,30 @@ async def _handle_agent_event(
431
427
match agent_event :
432
428
case PartStartEvent ():
433
429
# If we have a previous part end it.
434
- part_end : str | None
430
+ part_end : BaseEvent | None
435
431
for part_end in stream_ctx .part_ends :
436
432
yield part_end
437
433
stream_ctx .part_ends .clear ()
438
434
439
435
match agent_event .part :
440
436
case TextPart ():
441
437
message_id : str = stream_ctx .new_message_id ()
442
- yield encoder .encode (
443
- TextMessageStartEvent (
444
- type = EventType .TEXT_MESSAGE_START ,
445
- message_id = message_id ,
446
- role = Role .ASSISTANT .value ,
447
- ),
438
+ yield TextMessageStartEvent (
439
+ type = EventType .TEXT_MESSAGE_START ,
440
+ message_id = message_id ,
441
+ role = Role .ASSISTANT .value ,
448
442
)
449
443
stream_ctx .part_ends = [
450
- encoder .encode (
451
- TextMessageEndEvent (
452
- type = EventType .TEXT_MESSAGE_END ,
453
- message_id = message_id ,
454
- ),
444
+ TextMessageEndEvent (
445
+ type = EventType .TEXT_MESSAGE_END ,
446
+ message_id = message_id ,
455
447
),
456
448
]
457
449
if agent_event .part .content :
458
- yield encoder .encode ( # pragma: no cover
459
- TextMessageContentEvent (
460
- type = EventType .TEXT_MESSAGE_CONTENT ,
461
- message_id = message_id ,
462
- delta = agent_event .part .content ,
463
- ),
450
+ yield TextMessageContentEvent ( # pragma: no cover
451
+ type = EventType .TEXT_MESSAGE_CONTENT ,
452
+ message_id = message_id ,
453
+ delta = agent_event .part .content ,
464
454
)
465
455
case ToolCallPart (): # pragma: no branch
466
456
tool_name : str | None = tool_names .get (agent_event .part .tool_name )
@@ -469,19 +459,15 @@ async def _handle_agent_event(
469
459
return
470
460
471
461
stream_ctx .last_tool_call_id = agent_event .part .tool_call_id
472
- yield encoder .encode (
473
- ToolCallStartEvent (
474
- type = EventType .TOOL_CALL_START ,
475
- tool_call_id = agent_event .part .tool_call_id ,
476
- tool_call_name = tool_name or agent_event .part .tool_name ,
477
- ),
462
+ yield ToolCallStartEvent (
463
+ type = EventType .TOOL_CALL_START ,
464
+ tool_call_id = agent_event .part .tool_call_id ,
465
+ tool_call_name = tool_name or agent_event .part .tool_name ,
478
466
)
479
467
stream_ctx .part_ends = [
480
- encoder .encode (
481
- ToolCallEndEvent (
482
- type = EventType .TOOL_CALL_END ,
483
- tool_call_id = agent_event .part .tool_call_id ,
484
- ),
468
+ ToolCallEndEvent (
469
+ type = EventType .TOOL_CALL_END ,
470
+ tool_call_id = agent_event .part .tool_call_id ,
485
471
),
486
472
None , # Signal continuation of the stream.
487
473
]
@@ -491,28 +477,24 @@ async def _handle_agent_event(
491
477
case PartDeltaEvent ():
492
478
match agent_event .delta :
493
479
case TextPartDelta ():
494
- yield encoder .encode (
495
- TextMessageContentEvent (
496
- type = EventType .TEXT_MESSAGE_CONTENT ,
497
- message_id = stream_ctx .message_id ,
498
- delta = agent_event .delta .content_delta ,
499
- ),
480
+ yield TextMessageContentEvent (
481
+ type = EventType .TEXT_MESSAGE_CONTENT ,
482
+ message_id = stream_ctx .message_id ,
483
+ delta = agent_event .delta .content_delta ,
500
484
)
501
485
case ToolCallPartDelta (): # pragma: no branch
502
486
if agent_event .delta .tool_call_id in stream_ctx .local_tool_calls :
503
487
# Local tool calls are not sent to the UI.
504
488
return
505
489
506
- yield encoder .encode (
507
- ToolCallArgsEvent (
508
- type = EventType .TOOL_CALL_ARGS ,
509
- tool_call_id = agent_event .delta .tool_call_id
510
- or stream_ctx .last_tool_call_id
511
- or 'unknown' , # Should never be unknown, but just in case.
512
- delta = agent_event .delta .args_delta
513
- if isinstance (agent_event .delta .args_delta , str )
514
- else json .dumps (agent_event .delta .args_delta ),
515
- ),
490
+ yield ToolCallArgsEvent (
491
+ type = EventType .TOOL_CALL_ARGS ,
492
+ tool_call_id = agent_event .delta .tool_call_id
493
+ or stream_ctx .last_tool_call_id
494
+ or 'unknown' , # Should never be unknown, but just in case.
495
+ delta = agent_event .delta .args_delta
496
+ if isinstance (agent_event .delta .args_delta , str )
497
+ else json .dumps (agent_event .delta .args_delta ),
516
498
)
517
499
case ThinkingPartDelta (): # pragma: no branch
518
500
# No equivalent AG-UI event yet.
0 commit comments