18
18
Callable ,
19
19
Final ,
20
20
Generic ,
21
- NoReturn ,
22
21
Protocol ,
23
22
TypeVar ,
24
- Union ,
25
23
cast ,
26
24
runtime_checkable ,
27
25
)
75
73
76
74
from pydantic import BaseModel , ValidationError
77
75
76
+ from pydantic_ai .output import DeferredToolCalls
77
+ from pydantic_ai .tools import ToolDefinition
78
+ from pydantic_ai .toolset import AbstractToolset , DeferredToolset
79
+
78
80
from . import Agent , models
79
81
from ._agent_graph import ModelRequestNode
80
82
from .agent import RunOutputDataT
100
102
from .output import OutputDataT , OutputSpec
101
103
from .result import AgentStream
102
104
from .settings import ModelSettings
103
- from .tools import AgentDepsT , Tool
105
+ from .tools import AgentDepsT
104
106
from .usage import Usage , UsageLimits
105
107
106
108
if TYPE_CHECKING :
@@ -139,7 +141,7 @@ def __init__(
139
141
usage_limits : UsageLimits | None = None ,
140
142
usage : Usage | None = None ,
141
143
infer_name : bool = True ,
142
- additional_tools : Sequence [Tool [AgentDepsT ]] | None = None ,
144
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
143
145
# Starlette
144
146
debug : bool = False ,
145
147
routes : Sequence [BaseRoute ] | None = None ,
@@ -164,7 +166,7 @@ def __init__(
164
166
usage_limits: Optional limits on model request count or token usage.
165
167
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
166
168
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
167
- additional_tools: Additional tools to use for this run .
169
+ toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset .
168
170
169
171
debug: Boolean indicating if debug tracebacks should be returned on errors.
170
172
routes: A list of routes to serve incoming HTTP and WebSocket requests.
@@ -218,7 +220,7 @@ async def endpoint(request: Request) -> Response | StreamingResponse:
218
220
usage_limits = usage_limits ,
219
221
usage = usage ,
220
222
infer_name = infer_name ,
221
- additional_tools = additional_tools ,
223
+ toolsets = toolsets ,
222
224
),
223
225
media_type = SSE_CONTENT_TYPE ,
224
226
)
@@ -241,7 +243,7 @@ def agent_to_ag_ui(
241
243
usage_limits : UsageLimits | None = None ,
242
244
usage : Usage | None = None ,
243
245
infer_name : bool = True ,
244
- additional_tools : Sequence [Tool [AgentDepsT ]] | None = None ,
246
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
245
247
# Starlette parameters.
246
248
debug : bool = False ,
247
249
routes : Sequence [BaseRoute ] | None = None ,
@@ -268,7 +270,7 @@ def agent_to_ag_ui(
268
270
usage_limits: Optional limits on model request count or token usage.
269
271
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
270
272
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
271
- additional_tools: Additional tools to use for this run .
273
+ toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset .
272
274
273
275
debug: Boolean indicating if debug tracebacks should be returned on errors.
274
276
routes: A list of routes to serve incoming HTTP and WebSocket requests.
@@ -308,7 +310,7 @@ def agent_to_ag_ui(
308
310
usage_limits = usage_limits ,
309
311
usage = usage ,
310
312
infer_name = infer_name ,
311
- additional_tools = additional_tools ,
313
+ toolsets = toolsets ,
312
314
# Starlette
313
315
debug = debug ,
314
316
routes = routes ,
@@ -402,7 +404,7 @@ async def run(
402
404
usage_limits : UsageLimits | None = None ,
403
405
usage : Usage | None = None ,
404
406
infer_name : bool = True ,
405
- additional_tools : Sequence [Tool [AgentDepsT ]] | None = None ,
407
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
406
408
) -> AsyncGenerator [str , None ]:
407
409
"""Run the agent with streaming response using AG-UI protocol events.
408
410
@@ -420,7 +422,7 @@ async def run(
420
422
usage_limits: Optional limits on model request count or token usage.
421
423
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
422
424
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
423
- additional_tools: Additional tools to use for this run .
425
+ toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset .
424
426
425
427
Yields:
426
428
Streaming SSE-formatted event chunks.
@@ -429,8 +431,9 @@ async def run(
429
431
430
432
tool_names : dict [str , str ] = {self .tool_prefix + tool .name : tool .name for tool in run_input .tools }
431
433
encoder : EventEncoder = EventEncoder (accept = accept )
432
- run_tools : list [Tool [AgentDepsT ]] = list (additional_tools ) if additional_tools else []
433
- run_tools .extend (self ._convert_tools (run_input .tools ))
434
+ run_toolset : list [AbstractToolset [AgentDepsT ]] = list (toolsets ) if toolsets else []
435
+ if run_input .tools :
436
+ run_toolset .append (_AGUIToolset [AgentDepsT ](run_input .tools ))
434
437
435
438
try :
436
439
yield encoder .encode (
@@ -452,17 +455,18 @@ async def run(
452
455
run : AgentRun [AgentDepsT , Any ]
453
456
async with self .agent .iter (
454
457
user_prompt = None ,
455
- output_type = output_type ,
458
+ # TODO(steve): This should be output_type not str.
459
+ output_type = cast (OutputSpec [Any ], [str , DeferredToolCalls ]),
456
460
message_history = history .messages ,
457
461
model = model ,
458
462
deps = deps ,
459
463
model_settings = model_settings ,
460
464
usage_limits = usage_limits ,
461
465
usage = usage ,
462
466
infer_name = infer_name ,
463
- additional_tools = run_tools ,
467
+ toolsets = run_toolset ,
464
468
) as run :
465
- async for event in self ._agent_stream (tool_names , run , history . prompt_message_id ):
469
+ async for event in self ._agent_stream (tool_names , run , history ):
466
470
if event is None :
467
471
# Tool call signals early return, so we stop processing.
468
472
self .logger .debug ('tool call early return' )
@@ -493,14 +497,14 @@ async def run(
493
497
async def _tool_events (
494
498
self ,
495
499
parts : list [ModelRequestPart ],
496
- prompt_message_id : str ,
500
+ history : _History ,
497
501
) -> AsyncGenerator [BaseEvent | None , None ]:
498
502
"""Check for tool call results that are AG-UI events.
499
503
500
504
Args:
501
505
encoder: The event encoder to use for encoding events.
502
506
parts: The list of request parts to check for tool event returns.
503
- prompt_message_id : The message ID of the user prompt to use for tool call results .
507
+ history : The history of messages and tool calls to use for the run .
504
508
505
509
Yields:
506
510
AG-UI Server-Sent Events (SSE).
@@ -510,8 +514,12 @@ async def _tool_events(
510
514
if not isinstance (part , ToolReturnPart ):
511
515
continue
512
516
517
+ if part .tool_call_id in history .tool_calls :
518
+ # Tool call was passed in the history, so we skip it.
519
+ continue
520
+
513
521
yield ToolCallResultEvent (
514
- message_id = prompt_message_id ,
522
+ message_id = history . prompt_message_id ,
515
523
type = EventType .TOOL_CALL_RESULT ,
516
524
role = Role .TOOL .value ,
517
525
tool_call_id = part .tool_call_id ,
@@ -534,64 +542,18 @@ async def _tool_events(
534
542
self .logger .debug ('ag-ui event: %s' , item )
535
543
yield item
536
544
537
- def _convert_tools (self , run_tools : list [ToolAGUI ]) -> list [Tool [AgentDepsT ]]:
538
- """Convert AG-UI tools to PydanticAI tools.
539
-
540
- Creates `Tool` objects from AG-UI tool definitions. These tools don't
541
- actually execute anything, that is done by AG-UI client - they just
542
- provide the necessary tool definitions to PydanticAI agent.
543
-
544
- Args:
545
- run_tools: List of AG-UI tool definitions to convert.
546
-
547
- Returns:
548
- List of PydanticAI Tool objects that call the AG-UI tools.
549
- """
550
- return [self ._tool_call (tool ) for tool in run_tools ]
551
-
552
- def _tool_call (self , tool : ToolAGUI ) -> Tool [AgentDepsT ]:
553
- """Create a PydanticAI tool from an AG-UI tool definition.
554
-
555
- Args:
556
- tool: The AG-UI tool definition to convert.
557
-
558
- Returns:
559
- A PydanticAI `Tool` object that calls the AG-UI tool.
560
- """
561
-
562
- def _tool_stub (* args : Any , ** kwargs : Any ) -> NoReturn :
563
- """Stub function which is never called.
564
-
565
- Returns:
566
- Never returns as it always raises an exception.
567
-
568
- Raises:
569
- _UnexpectedToolCallError: Always raised since this should never be called.
570
- """
571
- raise _UnexpectedToolCallError (tool_name = tool .name ) # pragma: no cover
572
-
573
- return cast (
574
- 'Tool[AgentDepsT]' ,
575
- Tool .from_schema (
576
- function = _tool_stub ,
577
- name = tool .name ,
578
- description = tool .description ,
579
- json_schema = tool .parameters ,
580
- ),
581
- )
582
-
583
545
async def _agent_stream (
584
546
self ,
585
547
tool_names : dict [str , str ],
586
548
run : AgentRun [AgentDepsT , Any ],
587
- prompt_message_id : str ,
549
+ history : _History ,
588
550
) -> AsyncGenerator [BaseEvent | None , None ]:
589
551
"""Run the agent streaming responses using AG-UI protocol events.
590
552
591
553
Args:
592
554
tool_names: A mapping of tool names to their AG-UI names.
593
555
run: The agent run to process.
594
- prompt_message_id : The message ID of the user prompt to use for tool call results .
556
+ history : The history of messages and tool calls to use for the run .
595
557
596
558
Yields:
597
559
AG-UI Server-Sent Events (SSE).
@@ -605,7 +567,7 @@ async def _agent_stream(
605
567
continue
606
568
607
569
# Check for tool results.
608
- async for msg in self ._tool_events (node .request .parts , prompt_message_id ):
570
+ async for msg in self ._tool_events (node .request .parts , history ):
609
571
yield msg
610
572
611
573
stream_ctx : _RequestStreamContext = _RequestStreamContext ()
@@ -616,8 +578,9 @@ async def _agent_stream(
616
578
async for msg in self ._handle_agent_event (tool_names , stream_ctx , agent_event ):
617
579
yield msg
618
580
619
- for part_end in stream_ctx .part_ends :
620
- yield part_end
581
+ if stream_ctx .part_end :
582
+ yield stream_ctx .part_end
583
+ stream_ctx .part_end = None
621
584
622
585
async def _handle_agent_event (
623
586
self ,
@@ -638,11 +601,10 @@ async def _handle_agent_event(
638
601
"""
639
602
self .logger .debug ('agent_event: %s' , agent_event )
640
603
if isinstance (agent_event , PartStartEvent ):
641
- # If we have a previous part end it.
642
- part_end : BaseEvent | None
643
- for part_end in stream_ctx .part_ends :
644
- yield part_end
645
- stream_ctx .part_ends .clear ()
604
+ if stream_ctx .part_end :
605
+ # End the previous part.
606
+ yield stream_ctx .part_end
607
+ stream_ctx .part_end = None
646
608
647
609
if isinstance (agent_event .part , TextPart ):
648
610
message_id : str = stream_ctx .new_message_id ()
@@ -651,12 +613,10 @@ async def _handle_agent_event(
651
613
message_id = message_id ,
652
614
role = Role .ASSISTANT .value ,
653
615
)
654
- stream_ctx .part_ends = [
655
- TextMessageEndEvent (
656
- type = EventType .TEXT_MESSAGE_END ,
657
- message_id = message_id ,
658
- ),
659
- ]
616
+ stream_ctx .part_end = TextMessageEndEvent (
617
+ type = EventType .TEXT_MESSAGE_END ,
618
+ message_id = message_id ,
619
+ )
660
620
if agent_event .part .content :
661
621
yield TextMessageContentEvent ( # pragma: no cover
662
622
type = EventType .TEXT_MESSAGE_CONTENT ,
@@ -671,15 +631,10 @@ async def _handle_agent_event(
671
631
tool_call_id = agent_event .part .tool_call_id ,
672
632
tool_call_name = tool_name or agent_event .part .tool_name ,
673
633
)
674
- stream_ctx .part_ends = [
675
- ToolCallEndEvent (
676
- type = EventType .TOOL_CALL_END ,
677
- tool_call_id = agent_event .part .tool_call_id ,
678
- ),
679
- ]
680
- if tool_name :
681
- # AG-UI tool, signal continuation of the stream.
682
- stream_ctx .part_ends .append (None )
634
+ stream_ctx .part_end = ToolCallEndEvent (
635
+ type = EventType .TOOL_CALL_END ,
636
+ tool_call_id = agent_event .part .tool_call_id ,
637
+ )
683
638
684
639
elif isinstance (agent_event .part , ThinkingPart ): # pragma: no branch
685
640
yield ThinkingTextMessageStartEvent (
@@ -690,11 +645,10 @@ async def _handle_agent_event(
690
645
type = EventType .THINKING_TEXT_MESSAGE_CONTENT ,
691
646
delta = agent_event .part .content ,
692
647
)
693
- stream_ctx .part_ends = [
694
- ThinkingTextMessageEndEvent (
695
- type = EventType .THINKING_TEXT_MESSAGE_END ,
696
- ),
697
- ]
648
+ stream_ctx .part_end = ThinkingTextMessageEndEvent (
649
+ type = EventType .THINKING_TEXT_MESSAGE_END ,
650
+ )
651
+
698
652
elif isinstance (agent_event , PartDeltaEvent ):
699
653
if isinstance (agent_event .delta , TextPartDelta ):
700
654
yield TextMessageContentEvent (
@@ -728,6 +682,7 @@ class _History:
728
682
729
683
prompt_message_id : str # The ID of the last user message.
730
684
messages : list [ModelMessage ]
685
+ tool_calls : set [str ] = field (default_factory = set )
731
686
732
687
733
688
def _convert_history (messages : list [Message ]) -> _History :
@@ -742,7 +697,7 @@ def _convert_history(messages: list[Message]) -> _History:
742
697
msg : Message
743
698
prompt_message_id : str = ''
744
699
result : list [ModelMessage ] = []
745
- tool_calls : dict [str , str ] = {}
700
+ tool_calls : dict [str , str ] = {} # Tool call ID to tool name mapping.
746
701
for msg in messages :
747
702
if isinstance (msg , UserMessage ):
748
703
prompt_message_id = msg .id
@@ -789,6 +744,7 @@ def _convert_history(messages: list[Message]) -> _History:
789
744
return _History (
790
745
prompt_message_id = prompt_message_id ,
791
746
messages = result ,
747
+ tool_calls = set (tool_calls .keys ()),
792
748
)
793
749
794
750
@@ -826,21 +782,6 @@ def __str__(self) -> str:
826
782
return self .message
827
783
828
784
829
- class _UnexpectedToolCallError (_RunError ):
830
- """Exception raised when an unexpected tool call is encountered."""
831
-
832
- def __init__ (self , * , tool_name : str ) -> None :
833
- """Initialize the unexpected tool call error.
834
-
835
- Args:
836
- tool_name: The name of the tool that was unexpectedly called.
837
- """
838
- super ().__init__ (
839
- message = f'unexpected tool call name={ tool_name } ' , # pragma: no cover
840
- code = 'unexpected_tool_call' ,
841
- )
842
-
843
-
844
785
@dataclass
845
786
class _NoMessagesError (_RunError ):
846
787
"""Exception raised when no messages are found in the input."""
@@ -926,7 +867,7 @@ class _RequestStreamContext:
926
867
927
868
message_id : str = ''
928
869
last_tool_call_id : str | None = None
929
- part_ends : list [ BaseEvent | None ] = field ( default_factory = lambda : list [ Union [ BaseEvent , None ]]())
870
+ part_end : BaseEvent | None = None
930
871
931
872
def new_message_id (self ) -> str :
932
873
"""Generate a new message ID for the request stream.
@@ -938,3 +879,15 @@ def new_message_id(self) -> str:
938
879
"""
939
880
self .message_id = str (uuid .uuid4 ())
940
881
return self .message_id
882
+
883
+
884
+ class _AGUIToolset (DeferredToolset [AgentDepsT ]):
885
+ """A toolset that is used for AG-UI."""
886
+
887
+ def __init__ (self , tools : list [ToolAGUI ]) -> None :
888
+ super ().__init__ (
889
+ [
890
+ ToolDefinition (name = tool .name , description = tool .description , parameters_json_schema = tool .parameters )
891
+ for tool in tools
892
+ ]
893
+ )
0 commit comments