Skip to content

Commit caddcc0

Browse files
committed
chore(ag-ui): use toolsets
Switch from additional tools to the new toolset system.
1 parent 58137bb commit caddcc0

File tree

4 files changed

+76
-119
lines changed

4 files changed

+76
-119
lines changed

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 65 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,8 @@
1818
Callable,
1919
Final,
2020
Generic,
21-
NoReturn,
2221
Protocol,
2322
TypeVar,
24-
Union,
2523
cast,
2624
runtime_checkable,
2725
)
@@ -75,6 +73,10 @@
7573

7674
from pydantic import BaseModel, ValidationError
7775

76+
from pydantic_ai.output import DeferredToolCalls
77+
from pydantic_ai.tools import ToolDefinition
78+
from pydantic_ai.toolset import AbstractToolset, DeferredToolset
79+
7880
from . import Agent, models
7981
from ._agent_graph import ModelRequestNode
8082
from .agent import RunOutputDataT
@@ -100,7 +102,7 @@
100102
from .output import OutputDataT, OutputSpec
101103
from .result import AgentStream
102104
from .settings import ModelSettings
103-
from .tools import AgentDepsT, Tool
105+
from .tools import AgentDepsT
104106
from .usage import Usage, UsageLimits
105107

106108
if TYPE_CHECKING:
@@ -139,7 +141,7 @@ def __init__(
139141
usage_limits: UsageLimits | None = None,
140142
usage: Usage | None = None,
141143
infer_name: bool = True,
142-
additional_tools: Sequence[Tool[AgentDepsT]] | None = None,
144+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
143145
# Starlette
144146
debug: bool = False,
145147
routes: Sequence[BaseRoute] | None = None,
@@ -164,7 +166,7 @@ def __init__(
164166
usage_limits: Optional limits on model request count or token usage.
165167
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
166168
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
167-
additional_tools: Additional tools to use for this run.
169+
toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.
168170
169171
debug: Boolean indicating if debug tracebacks should be returned on errors.
170172
routes: A list of routes to serve incoming HTTP and WebSocket requests.
@@ -218,7 +220,7 @@ async def endpoint(request: Request) -> Response | StreamingResponse:
218220
usage_limits=usage_limits,
219221
usage=usage,
220222
infer_name=infer_name,
221-
additional_tools=additional_tools,
223+
toolsets=toolsets,
222224
),
223225
media_type=SSE_CONTENT_TYPE,
224226
)
@@ -241,7 +243,7 @@ def agent_to_ag_ui(
241243
usage_limits: UsageLimits | None = None,
242244
usage: Usage | None = None,
243245
infer_name: bool = True,
244-
additional_tools: Sequence[Tool[AgentDepsT]] | None = None,
246+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
245247
# Starlette parameters.
246248
debug: bool = False,
247249
routes: Sequence[BaseRoute] | None = None,
@@ -268,7 +270,7 @@ def agent_to_ag_ui(
268270
usage_limits: Optional limits on model request count or token usage.
269271
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
270272
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
271-
additional_tools: Additional tools to use for this run.
273+
toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.
272274
273275
debug: Boolean indicating if debug tracebacks should be returned on errors.
274276
routes: A list of routes to serve incoming HTTP and WebSocket requests.
@@ -308,7 +310,7 @@ def agent_to_ag_ui(
308310
usage_limits=usage_limits,
309311
usage=usage,
310312
infer_name=infer_name,
311-
additional_tools=additional_tools,
313+
toolsets=toolsets,
312314
# Starlette
313315
debug=debug,
314316
routes=routes,
@@ -402,7 +404,7 @@ async def run(
402404
usage_limits: UsageLimits | None = None,
403405
usage: Usage | None = None,
404406
infer_name: bool = True,
405-
additional_tools: Sequence[Tool[AgentDepsT]] | None = None,
407+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
406408
) -> AsyncGenerator[str, None]:
407409
"""Run the agent with streaming response using AG-UI protocol events.
408410
@@ -420,7 +422,7 @@ async def run(
420422
usage_limits: Optional limits on model request count or token usage.
421423
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
422424
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
423-
additional_tools: Additional tools to use for this run.
425+
toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset.
424426
425427
Yields:
426428
Streaming SSE-formatted event chunks.
@@ -429,8 +431,9 @@ async def run(
429431

430432
tool_names: dict[str, str] = {self.tool_prefix + tool.name: tool.name for tool in run_input.tools}
431433
encoder: EventEncoder = EventEncoder(accept=accept)
432-
run_tools: list[Tool[AgentDepsT]] = list(additional_tools) if additional_tools else []
433-
run_tools.extend(self._convert_tools(run_input.tools))
434+
run_toolset: list[AbstractToolset[AgentDepsT]] = list(toolsets) if toolsets else []
435+
if run_input.tools:
436+
run_toolset.append(_AGUIToolset[AgentDepsT](run_input.tools))
434437

435438
try:
436439
yield encoder.encode(
@@ -452,17 +455,18 @@ async def run(
452455
run: AgentRun[AgentDepsT, Any]
453456
async with self.agent.iter(
454457
user_prompt=None,
455-
output_type=output_type,
458+
# TODO(steve): This should be output_type not str.
459+
output_type=cast(OutputSpec[Any], [str, DeferredToolCalls]),
456460
message_history=history.messages,
457461
model=model,
458462
deps=deps,
459463
model_settings=model_settings,
460464
usage_limits=usage_limits,
461465
usage=usage,
462466
infer_name=infer_name,
463-
additional_tools=run_tools,
467+
toolsets=run_toolset,
464468
) as run:
465-
async for event in self._agent_stream(tool_names, run, history.prompt_message_id):
469+
async for event in self._agent_stream(tool_names, run, history):
466470
if event is None:
467471
# Tool call signals early return, so we stop processing.
468472
self.logger.debug('tool call early return')
@@ -493,14 +497,14 @@ async def run(
493497
async def _tool_events(
494498
self,
495499
parts: list[ModelRequestPart],
496-
prompt_message_id: str,
500+
history: _History,
497501
) -> AsyncGenerator[BaseEvent | None, None]:
498502
"""Check for tool call results that are AG-UI events.
499503
500504
Args:
501505
encoder: The event encoder to use for encoding events.
502506
parts: The list of request parts to check for tool event returns.
503-
prompt_message_id: The message ID of the user prompt to use for tool call results.
507+
history: The history of messages and tool calls to use for the run.
504508
505509
Yields:
506510
AG-UI Server-Sent Events (SSE).
@@ -510,8 +514,12 @@ async def _tool_events(
510514
if not isinstance(part, ToolReturnPart):
511515
continue
512516

517+
if part.tool_call_id in history.tool_calls:
518+
# Tool call was passed in the history, so we skip it.
519+
continue
520+
513521
yield ToolCallResultEvent(
514-
message_id=prompt_message_id,
522+
message_id=history.prompt_message_id,
515523
type=EventType.TOOL_CALL_RESULT,
516524
role=Role.TOOL.value,
517525
tool_call_id=part.tool_call_id,
@@ -534,64 +542,18 @@ async def _tool_events(
534542
self.logger.debug('ag-ui event: %s', item)
535543
yield item
536544

537-
def _convert_tools(self, run_tools: list[ToolAGUI]) -> list[Tool[AgentDepsT]]:
538-
"""Convert AG-UI tools to PydanticAI tools.
539-
540-
Creates `Tool` objects from AG-UI tool definitions. These tools don't
541-
actually execute anything, that is done by AG-UI client - they just
542-
provide the necessary tool definitions to PydanticAI agent.
543-
544-
Args:
545-
run_tools: List of AG-UI tool definitions to convert.
546-
547-
Returns:
548-
List of PydanticAI Tool objects that call the AG-UI tools.
549-
"""
550-
return [self._tool_call(tool) for tool in run_tools]
551-
552-
def _tool_call(self, tool: ToolAGUI) -> Tool[AgentDepsT]:
553-
"""Create a PydanticAI tool from an AG-UI tool definition.
554-
555-
Args:
556-
tool: The AG-UI tool definition to convert.
557-
558-
Returns:
559-
A PydanticAI `Tool` object that calls the AG-UI tool.
560-
"""
561-
562-
def _tool_stub(*args: Any, **kwargs: Any) -> NoReturn:
563-
"""Stub function which is never called.
564-
565-
Returns:
566-
Never returns as it always raises an exception.
567-
568-
Raises:
569-
_UnexpectedToolCallError: Always raised since this should never be called.
570-
"""
571-
raise _UnexpectedToolCallError(tool_name=tool.name) # pragma: no cover
572-
573-
return cast(
574-
'Tool[AgentDepsT]',
575-
Tool.from_schema(
576-
function=_tool_stub,
577-
name=tool.name,
578-
description=tool.description,
579-
json_schema=tool.parameters,
580-
),
581-
)
582-
583545
async def _agent_stream(
584546
self,
585547
tool_names: dict[str, str],
586548
run: AgentRun[AgentDepsT, Any],
587-
prompt_message_id: str,
549+
history: _History,
588550
) -> AsyncGenerator[BaseEvent | None, None]:
589551
"""Run the agent streaming responses using AG-UI protocol events.
590552
591553
Args:
592554
tool_names: A mapping of tool names to their AG-UI names.
593555
run: The agent run to process.
594-
prompt_message_id: The message ID of the user prompt to use for tool call results.
556+
history: The history of messages and tool calls to use for the run.
595557
596558
Yields:
597559
AG-UI Server-Sent Events (SSE).
@@ -605,7 +567,7 @@ async def _agent_stream(
605567
continue
606568

607569
# Check for tool results.
608-
async for msg in self._tool_events(node.request.parts, prompt_message_id):
570+
async for msg in self._tool_events(node.request.parts, history):
609571
yield msg
610572

611573
stream_ctx: _RequestStreamContext = _RequestStreamContext()
@@ -616,8 +578,9 @@ async def _agent_stream(
616578
async for msg in self._handle_agent_event(tool_names, stream_ctx, agent_event):
617579
yield msg
618580

619-
for part_end in stream_ctx.part_ends:
620-
yield part_end
581+
if stream_ctx.part_end:
582+
yield stream_ctx.part_end
583+
stream_ctx.part_end = None
621584

622585
async def _handle_agent_event(
623586
self,
@@ -638,11 +601,10 @@ async def _handle_agent_event(
638601
"""
639602
self.logger.debug('agent_event: %s', agent_event)
640603
if isinstance(agent_event, PartStartEvent):
641-
# If we have a previous part end it.
642-
part_end: BaseEvent | None
643-
for part_end in stream_ctx.part_ends:
644-
yield part_end
645-
stream_ctx.part_ends.clear()
604+
if stream_ctx.part_end:
605+
# End the previous part.
606+
yield stream_ctx.part_end
607+
stream_ctx.part_end = None
646608

647609
if isinstance(agent_event.part, TextPart):
648610
message_id: str = stream_ctx.new_message_id()
@@ -651,12 +613,10 @@ async def _handle_agent_event(
651613
message_id=message_id,
652614
role=Role.ASSISTANT.value,
653615
)
654-
stream_ctx.part_ends = [
655-
TextMessageEndEvent(
656-
type=EventType.TEXT_MESSAGE_END,
657-
message_id=message_id,
658-
),
659-
]
616+
stream_ctx.part_end = TextMessageEndEvent(
617+
type=EventType.TEXT_MESSAGE_END,
618+
message_id=message_id,
619+
)
660620
if agent_event.part.content:
661621
yield TextMessageContentEvent( # pragma: no cover
662622
type=EventType.TEXT_MESSAGE_CONTENT,
@@ -671,15 +631,10 @@ async def _handle_agent_event(
671631
tool_call_id=agent_event.part.tool_call_id,
672632
tool_call_name=tool_name or agent_event.part.tool_name,
673633
)
674-
stream_ctx.part_ends = [
675-
ToolCallEndEvent(
676-
type=EventType.TOOL_CALL_END,
677-
tool_call_id=agent_event.part.tool_call_id,
678-
),
679-
]
680-
if tool_name:
681-
# AG-UI tool, signal continuation of the stream.
682-
stream_ctx.part_ends.append(None)
634+
stream_ctx.part_end = ToolCallEndEvent(
635+
type=EventType.TOOL_CALL_END,
636+
tool_call_id=agent_event.part.tool_call_id,
637+
)
683638

684639
elif isinstance(agent_event.part, ThinkingPart): # pragma: no branch
685640
yield ThinkingTextMessageStartEvent(
@@ -690,11 +645,10 @@ async def _handle_agent_event(
690645
type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
691646
delta=agent_event.part.content,
692647
)
693-
stream_ctx.part_ends = [
694-
ThinkingTextMessageEndEvent(
695-
type=EventType.THINKING_TEXT_MESSAGE_END,
696-
),
697-
]
648+
stream_ctx.part_end = ThinkingTextMessageEndEvent(
649+
type=EventType.THINKING_TEXT_MESSAGE_END,
650+
)
651+
698652
elif isinstance(agent_event, PartDeltaEvent):
699653
if isinstance(agent_event.delta, TextPartDelta):
700654
yield TextMessageContentEvent(
@@ -728,6 +682,7 @@ class _History:
728682

729683
prompt_message_id: str # The ID of the last user message.
730684
messages: list[ModelMessage]
685+
tool_calls: set[str] = field(default_factory=set)
731686

732687

733688
def _convert_history(messages: list[Message]) -> _History:
@@ -742,7 +697,7 @@ def _convert_history(messages: list[Message]) -> _History:
742697
msg: Message
743698
prompt_message_id: str = ''
744699
result: list[ModelMessage] = []
745-
tool_calls: dict[str, str] = {}
700+
tool_calls: dict[str, str] = {} # Tool call ID to tool name mapping.
746701
for msg in messages:
747702
if isinstance(msg, UserMessage):
748703
prompt_message_id = msg.id
@@ -789,6 +744,7 @@ def _convert_history(messages: list[Message]) -> _History:
789744
return _History(
790745
prompt_message_id=prompt_message_id,
791746
messages=result,
747+
tool_calls=set(tool_calls.keys()),
792748
)
793749

794750

@@ -826,21 +782,6 @@ def __str__(self) -> str:
826782
return self.message
827783

828784

829-
class _UnexpectedToolCallError(_RunError):
830-
"""Exception raised when an unexpected tool call is encountered."""
831-
832-
def __init__(self, *, tool_name: str) -> None:
833-
"""Initialize the unexpected tool call error.
834-
835-
Args:
836-
tool_name: The name of the tool that was unexpectedly called.
837-
"""
838-
super().__init__(
839-
message=f'unexpected tool call name={tool_name}', # pragma: no cover
840-
code='unexpected_tool_call',
841-
)
842-
843-
844785
@dataclass
845786
class _NoMessagesError(_RunError):
846787
"""Exception raised when no messages are found in the input."""
@@ -926,7 +867,7 @@ class _RequestStreamContext:
926867

927868
message_id: str = ''
928869
last_tool_call_id: str | None = None
929-
part_ends: list[BaseEvent | None] = field(default_factory=lambda: list[Union[BaseEvent, None]]())
870+
part_end: BaseEvent | None = None
930871

931872
def new_message_id(self) -> str:
932873
"""Generate a new message ID for the request stream.
@@ -938,3 +879,15 @@ def new_message_id(self) -> str:
938879
"""
939880
self.message_id = str(uuid.uuid4())
940881
return self.message_id
882+
883+
884+
class _AGUIToolset(DeferredToolset[AgentDepsT]):
885+
"""A toolset that is used for AG-UI."""
886+
887+
def __init__(self, tools: list[ToolAGUI]) -> None:
888+
super().__init__(
889+
[
890+
ToolDefinition(name=tool.name, description=tool.description, parameters_json_schema=tool.parameters)
891+
for tool in tools
892+
]
893+
)

0 commit comments

Comments
 (0)