Skip to content

Commit 8ea3901

Browse files
committed
refactor(ag-ui): push encode to the top level
Push encode to the top level, eliminating the need to pass the encoder to lower levels which simplifies the code and makes it more maintainable.
1 parent 8d28862 commit 8ea3901

File tree

1 file changed

+44
-62
lines changed

1 file changed

+44
-62
lines changed

pydantic_ai_ag_ui/pydantic_ai_ag_ui/adapter.py

Lines changed: 44 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class _RequestStreamContext:
8585

8686
message_id: str = ''
8787
last_tool_call_id: str | None = None
88-
part_ends: list[str | None] = field(default_factory=lambda: list[str | None]())
88+
part_ends: list[BaseEvent | None] = field(default_factory=lambda: list[BaseEvent | None]())
8989
local_tool_calls: set[str] = field(default_factory=set)
9090

9191
def new_message_id(self) -> str:
@@ -217,6 +217,7 @@ async def run(
217217
Yields:
218218
Streaming SSE-formatted event chunks.
219219
"""
220+
self.logger.warning('level=%s name=%s', logging.getLevelName(self.logger.level), self.logger.name)
220221
self.logger.debug('starting run: %s', json.dumps(run_input.model_dump(), indent=2))
221222

222223
tool_names: dict[str, str] = {self.tool_prefix + tool.name: tool.name for tool in run_input.tools}
@@ -257,13 +258,13 @@ async def run(
257258
infer_name=infer_name,
258259
additional_tools=run_tools,
259260
) as run:
260-
async for event in self._agent_stream(encoder, tool_names, run):
261+
async for event in self._agent_stream(tool_names, run):
261262
if event is None:
262263
# Tool call signals early return, so we stop processing.
263264
self.logger.debug('tool call early return')
264265
break
265266

266-
yield event
267+
yield encoder.encode(event)
267268
except RunError as e:
268269
self.logger.exception('agent run')
269270
yield encoder.encode(
@@ -285,9 +286,7 @@ async def run(
285286

286287
self.logger.info('done thread_id=%s run_id=%s', run_input.thread_id, run_input.run_id)
287288

288-
async def _tool_events(
289-
self, encoder: EventEncoder, parts: list[ModelRequestPart]
290-
) -> AsyncGenerator[str | None, None]:
289+
async def _tool_events(self, parts: list[ModelRequestPart]) -> AsyncGenerator[BaseEvent | None, None]:
291290
"""Check for tool call results that are AG-UI events.
292291
293292
Args:
@@ -309,15 +308,15 @@ async def _tool_events(
309308
match part.content:
310309
case BaseEvent():
311310
self.logger.debug('ag-ui event: %s', part.content)
312-
yield encoder.encode(part.content)
311+
yield part.content
313312
case str() | bytes():
314313
# Avoid strings and bytes being checked as iterable.
315314
pass
316315
case Iterable() as iter:
317316
for item in iter:
318317
if isinstance(item, BaseEvent): # pragma: no branch
319318
self.logger.debug('ag-ui event: %s', item)
320-
yield encoder.encode(item)
319+
yield item
321320
case _: # pragma: no cover
322321
# Not currently interested in other types.
323322
pass
@@ -371,51 +370,48 @@ def _tool_stub(*args: Any, **kwargs: Any) -> ToolResult:
371370

372371
async def _agent_stream(
373372
self,
374-
encoder: EventEncoder,
375373
tool_names: dict[str, str],
376374
run: AgentRun[AgentDepsT, Any],
377-
) -> AsyncGenerator[str | None, None]:
375+
) -> AsyncGenerator[BaseEvent | None, None]:
378376
"""Run the agent streaming responses using AG-UI protocol events.
379377
380378
Args:
381-
encoder: The event encoder to use for encoding events.
382379
tool_names: A mapping of tool names to their AG-UI names.
383380
run: The agent run to process.
384381
385382
Yields:
386383
AG-UI Server-Sent Events (SSE).
387384
"""
388385
node: AgentNode[AgentDepsT, Any] | End[FinalResult[Any]]
389-
msg: str | None
386+
msg: BaseEvent | None
390387
async for node in run:
391388
self.logger.debug('processing node=%r', node)
392389
if not isinstance(node, ModelRequestNode):
393390
# Not interested UserPromptNode, CallToolsNode or End.
394391
continue
395392

396393
# Check for state updates.
397-
snapshot: str | None
398-
async for snapshot in self._tool_events(encoder, node.request.parts):
394+
snapshot: BaseEvent | None
395+
async for snapshot in self._tool_events(node.request.parts):
399396
yield snapshot
400397

401398
stream_ctx: _RequestStreamContext = _RequestStreamContext()
402399
request_stream: AgentStream[AgentDepsT]
403400
async with node.stream(run.ctx) as request_stream:
404401
agent_event: AgentStreamEvent
405402
async for agent_event in request_stream:
406-
async for msg in self._handle_agent_event(encoder, tool_names, stream_ctx, agent_event):
403+
async for msg in self._handle_agent_event(tool_names, stream_ctx, agent_event):
407404
yield msg
408405

409406
for part_end in stream_ctx.part_ends:
410407
yield part_end
411408

412409
async def _handle_agent_event(
413410
self,
414-
encoder: EventEncoder,
415411
tool_names: dict[str, str],
416412
stream_ctx: _RequestStreamContext,
417413
agent_event: AgentStreamEvent,
418-
) -> AsyncGenerator[str | None, None]:
414+
) -> AsyncGenerator[BaseEvent | None, None]:
419415
"""Handle an agent event and yield AG-UI protocol events.
420416
421417
Args:
@@ -431,36 +427,30 @@ async def _handle_agent_event(
431427
match agent_event:
432428
case PartStartEvent():
433429
# If we have a previous part end it.
434-
part_end: str | None
430+
part_end: BaseEvent | None
435431
for part_end in stream_ctx.part_ends:
436432
yield part_end
437433
stream_ctx.part_ends.clear()
438434

439435
match agent_event.part:
440436
case TextPart():
441437
message_id: str = stream_ctx.new_message_id()
442-
yield encoder.encode(
443-
TextMessageStartEvent(
444-
type=EventType.TEXT_MESSAGE_START,
445-
message_id=message_id,
446-
role=Role.ASSISTANT.value,
447-
),
438+
yield TextMessageStartEvent(
439+
type=EventType.TEXT_MESSAGE_START,
440+
message_id=message_id,
441+
role=Role.ASSISTANT.value,
448442
)
449443
stream_ctx.part_ends = [
450-
encoder.encode(
451-
TextMessageEndEvent(
452-
type=EventType.TEXT_MESSAGE_END,
453-
message_id=message_id,
454-
),
444+
TextMessageEndEvent(
445+
type=EventType.TEXT_MESSAGE_END,
446+
message_id=message_id,
455447
),
456448
]
457449
if agent_event.part.content:
458-
yield encoder.encode( # pragma: no cover
459-
TextMessageContentEvent(
460-
type=EventType.TEXT_MESSAGE_CONTENT,
461-
message_id=message_id,
462-
delta=agent_event.part.content,
463-
),
450+
yield TextMessageContentEvent( # pragma: no cover
451+
type=EventType.TEXT_MESSAGE_CONTENT,
452+
message_id=message_id,
453+
delta=agent_event.part.content,
464454
)
465455
case ToolCallPart(): # pragma: no branch
466456
tool_name: str | None = tool_names.get(agent_event.part.tool_name)
@@ -469,19 +459,15 @@ async def _handle_agent_event(
469459
return
470460

471461
stream_ctx.last_tool_call_id = agent_event.part.tool_call_id
472-
yield encoder.encode(
473-
ToolCallStartEvent(
474-
type=EventType.TOOL_CALL_START,
475-
tool_call_id=agent_event.part.tool_call_id,
476-
tool_call_name=tool_name or agent_event.part.tool_name,
477-
),
462+
yield ToolCallStartEvent(
463+
type=EventType.TOOL_CALL_START,
464+
tool_call_id=agent_event.part.tool_call_id,
465+
tool_call_name=tool_name or agent_event.part.tool_name,
478466
)
479467
stream_ctx.part_ends = [
480-
encoder.encode(
481-
ToolCallEndEvent(
482-
type=EventType.TOOL_CALL_END,
483-
tool_call_id=agent_event.part.tool_call_id,
484-
),
468+
ToolCallEndEvent(
469+
type=EventType.TOOL_CALL_END,
470+
tool_call_id=agent_event.part.tool_call_id,
485471
),
486472
None, # Signal continuation of the stream.
487473
]
@@ -491,28 +477,24 @@ async def _handle_agent_event(
491477
case PartDeltaEvent():
492478
match agent_event.delta:
493479
case TextPartDelta():
494-
yield encoder.encode(
495-
TextMessageContentEvent(
496-
type=EventType.TEXT_MESSAGE_CONTENT,
497-
message_id=stream_ctx.message_id,
498-
delta=agent_event.delta.content_delta,
499-
),
480+
yield TextMessageContentEvent(
481+
type=EventType.TEXT_MESSAGE_CONTENT,
482+
message_id=stream_ctx.message_id,
483+
delta=agent_event.delta.content_delta,
500484
)
501485
case ToolCallPartDelta(): # pragma: no branch
502486
if agent_event.delta.tool_call_id in stream_ctx.local_tool_calls:
503487
# Local tool calls are not sent to the UI.
504488
return
505489

506-
yield encoder.encode(
507-
ToolCallArgsEvent(
508-
type=EventType.TOOL_CALL_ARGS,
509-
tool_call_id=agent_event.delta.tool_call_id
510-
or stream_ctx.last_tool_call_id
511-
or 'unknown', # Should never be unknown, but just in case.
512-
delta=agent_event.delta.args_delta
513-
if isinstance(agent_event.delta.args_delta, str)
514-
else json.dumps(agent_event.delta.args_delta),
515-
),
490+
yield ToolCallArgsEvent(
491+
type=EventType.TOOL_CALL_ARGS,
492+
tool_call_id=agent_event.delta.tool_call_id
493+
or stream_ctx.last_tool_call_id
494+
or 'unknown', # Should never be unknown, but just in case.
495+
delta=agent_event.delta.args_delta
496+
if isinstance(agent_event.delta.args_delta, str)
497+
else json.dumps(agent_event.delta.args_delta),
516498
)
517499
case ThinkingPartDelta(): # pragma: no branch
518500
# No equivalent AG-UI event yet.

0 commit comments

Comments
 (0)