diff --git a/docs/agents.md b/docs/agents.md index 3fe60f7eb..732039307 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -808,7 +808,7 @@ with capture_run_messages() as messages: # (2)! result = agent.run_sync('Please get me the volume of a box with size 6.') except UnexpectedModelBehavior as e: print('An error occurred:', e) - #> An error occurred: Tool exceeded max retries count of 1 + #> An error occurred: Tool 'calc_volume' exceeded max retries count of 1 print('cause:', repr(e.__cause__)) #> cause: ModelRetry('Please try again.') print('messages:', messages) diff --git a/docs/mcp/client.md b/docs/mcp/client.md index 7f8c5fdd6..33bd4e196 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -29,7 +29,7 @@ Examples of both are shown below; [mcp-run-python](run-python.md) is used as the [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] connects over HTTP using the [HTTP + Server Sent Events transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) to a server. !!! note - [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before calling [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers]. Running the server is not managed by PydanticAI. + [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before running the agent. Running the server is not managed by Pydantic AI. The name "HTTP" is used since this implementation will be adapted in future to use the new [Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development. @@ -47,11 +47,11 @@ from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerSSE server = MCPServerSSE(url='http://localhost:3001/sse') # (1)! -agent = Agent('openai:gpt-4o', mcp_servers=[server]) # (2)! +agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)! async def main(): - async with agent.run_mcp_servers(): # (3)! + async with agent: # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -92,9 +92,8 @@ Will display as follows: !!! note [`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] requires an MCP server to be - running and accepting HTTP connections before calling - [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers]. Running the server is not - managed by PydanticAI. + running and accepting HTTP connections before running the agent. Running the server is not + managed by Pydantic AI. Before creating the Streamable HTTP client, we need to run a server that supports the Streamable HTTP transport. @@ -118,10 +117,10 @@ from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStreamableHTTP server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! -agent = Agent('openai:gpt-4o', mcp_servers=[server]) # (2)! +agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)! async def main(): - async with agent.run_mcp_servers(): # (3)! + async with agent: # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -138,7 +137,7 @@ _(This example is complete, it can be run "as is" with Python 3.10+ — you'll n The other transport offered by MCP is the [stdio transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) where the server is run as a subprocess and communicates with the client over `stdin` and `stdout`. In this case, you'd use the [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] class. !!! note - When using [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] servers, the [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers] context manager is responsible for starting and stopping the server. + When using [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] servers, the [`async with agent`][pydantic_ai.Agent.__aenter__] context manager is responsible for starting and stopping the server. ```python {title="mcp_stdio_client.py" py="3.10"} from pydantic_ai import Agent @@ -156,11 +155,11 @@ server = MCPServerStdio( # (1)! 'stdio', ] ) -agent = Agent('openai:gpt-4o', mcp_servers=[server]) +agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -180,31 +179,32 @@ call needs. from typing import Any from pydantic_ai import Agent -from pydantic_ai.mcp import CallToolFunc, MCPServerStdio, ToolResult +from pydantic_ai.mcp import MCPServerStdio, ToolResult from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext +from pydantic_ai.toolsets.processed import CallToolFunc async def process_tool_call( ctx: RunContext[int], call_tool: CallToolFunc, - tool_name: str, - args: dict[str, Any], + name: str, + tool_args: dict[str, Any], ) -> ToolResult: """A tool call processor that passes along the deps.""" - return await call_tool(tool_name, args, metadata={'deps': ctx.deps}) + return await call_tool(name, tool_args, metadata={'deps': ctx.deps}) server = MCPServerStdio('python', ['mcp_server.py'], process_tool_call=process_tool_call) agent = Agent( model=TestModel(call_tools=['echo_deps']), deps_type=int, - mcp_servers=[server] + toolsets=[server] ) async def main(): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Echo with deps set to 42', deps=42) print(result.output) #> {"echo_deps":{"echo":"This is an echo message","deps":42}} @@ -242,7 +242,7 @@ calculator_server = MCPServerSSE( # Both servers might have a tool named 'get_data', but they'll be exposed as: # - 'weather_get_data' # - 'calc_get_data' -agent = Agent('openai:gpt-4o', mcp_servers=[weather_server, calculator_server]) +agent = Agent('openai:gpt-4o', toolsets=[weather_server, calculator_server]) ``` ### Example with Stdio Server @@ -272,7 +272,7 @@ js_server = MCPServerStdio( tool_prefix='js' # Tools will be prefixed with 'js_' ) -agent = Agent('openai:gpt-4o', mcp_servers=[python_server, js_server]) +agent = Agent('openai:gpt-4o', toolsets=[python_server, js_server]) ``` When the model interacts with these servers, it will see the prefixed tool names, but the prefixes will be automatically handled when making tool calls. @@ -359,11 +359,11 @@ from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStdio server = MCPServerStdio(command='python', args=['generate_svg.py']) -agent = Agent('openai:gpt-4o', mcp_servers=[server]) +agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Create an image of a robot in a punk style.') print(result.output) #> Image file written to robot_punk.svg. diff --git a/docs/output.md b/docs/output.md index f32e403d7..825f2ec23 100644 --- a/docs/output.md +++ b/docs/output.md @@ -200,8 +200,8 @@ async def hand_off_to_sql_agent(ctx: RunContext, query: str) -> list[Row]: return output except UnexpectedModelBehavior as e: # Bubble up potentially retryable errors to the router agent - if (cause := e.__cause__) and hasattr(cause, 'tool_retry'): - raise ModelRetry(f'SQL agent failed: {cause.tool_retry.content}') from e + if (cause := e.__cause__) and isinstance(cause, ModelRetry): + raise ModelRetry(f'SQL agent failed: {cause.message}') from e else: raise diff --git a/mcp-run-python/README.md b/mcp-run-python/README.md index 360ca2347..edd84ddb8 100644 --- a/mcp-run-python/README.md +++ b/mcp-run-python/README.md @@ -52,11 +52,11 @@ server = MCPServerStdio('deno', 'jsr:@pydantic/mcp-run-python', 'stdio', ]) -agent = Agent('claude-3-5-haiku-latest', mcp_servers=[server]) +agent = Agent('claude-3-5-haiku-latest', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025.w diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 4515d18bc..a4ac696cf 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -3,6 +3,8 @@ import asyncio import dataclasses import hashlib +import json +from collections import defaultdict, deque from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar @@ -14,16 +16,18 @@ from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore from pydantic_ai._utils import is_async_callable, run_in_executor +from pydantic_ai.toolsets import AbstractToolset +from pydantic_ai.toolsets._run import RunToolset from pydantic_graph import BaseNode, Graph, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage +from .exceptions import ToolRetryError from .output import OutputDataT, OutputSpec from .settings import ModelSettings, merge_model_settings -from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc +from .tools import RunContext, ToolDefinition, ToolKind if TYPE_CHECKING: - from .mcp import MCPServer from .models.instrumented import InstrumentationSettings __all__ = ( @@ -77,11 +81,13 @@ class GraphAgentState: retries: int run_step: int - def increment_retries(self, max_result_retries: int, error: Exception | None = None) -> None: + def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None: self.retries += 1 if self.retries > max_result_retries: - message = f'Exceeded maximum retries ({max_result_retries}) for result validation' + message = f'Exceeded maximum retries ({max_result_retries}) for output validation' if error: + if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None: + error = error.__cause__ raise exceptions.UnexpectedModelBehavior(message) from error else: raise exceptions.UnexpectedModelBehavior(message) @@ -108,15 +114,12 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): history_processors: Sequence[HistoryProcessor[DepsT]] - function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False) - mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) - default_retries: int + toolset: RunToolset[DepsT] + sampling_model: models.Model tracer: Tracer instrumentation_settings: InstrumentationSettings | None = None - prepare_tools: ToolsPrepareFunc[DepsT] | None = None - class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]): """The base class for all agent nodes. @@ -248,59 +251,27 @@ async def _prepare_request_parameters( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], ) -> models.ModelRequestParameters: """Build tools and create an agent model.""" - function_tool_defs_map: dict[str, ToolDefinition] = {} - run_context = build_run_context(ctx) - - async def add_tool(tool: Tool[DepsT]) -> None: - ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name) - if tool_def := await tool.prepare_tool_def(ctx): - # prepare_tool_def may change tool_def.name - if tool_def.name in function_tool_defs_map: - if tool_def.name != tool.name: - # Prepare tool def may have renamed the tool - raise exceptions.UserError( - f"Renaming tool '{tool.name}' to '{tool_def.name}' conflicts with existing tool." - ) - else: - raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}.') - function_tool_defs_map[tool_def.name] = tool_def - - async def add_mcp_server_tools(server: MCPServer) -> None: - if not server.is_running: - raise exceptions.UserError(f'MCP server is not running: {server}') - tool_defs = await server.list_tools() - for tool_def in tool_defs: - if tool_def.name in function_tool_defs_map: - raise exceptions.UserError( - f"MCP Server '{server}' defines a tool whose name conflicts with existing tool: {tool_def.name!r}. Consider using `tool_prefix` to avoid name conflicts." - ) - function_tool_defs_map[tool_def.name] = tool_def - - await asyncio.gather( - *map(add_tool, ctx.deps.function_tools.values()), - *map(add_mcp_server_tools, ctx.deps.mcp_servers), - ) - function_tool_defs = list(function_tool_defs_map.values()) - if ctx.deps.prepare_tools: - # Prepare the tools using the provided function - # This also acts over tool definitions pulled from MCP servers - function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or [] + ctx.deps.toolset = await ctx.deps.toolset.prepare_for_run(run_context) output_schema = ctx.deps.output_schema - - output_tools = [] output_object = None - if isinstance(output_schema, _output.ToolOutputSchema): - output_tools = output_schema.tool_defs() - elif isinstance(output_schema, _output.NativeOutputSchema): + if isinstance(output_schema, _output.NativeOutputSchema): output_object = output_schema.object_def # ToolOrTextOutputSchema, NativeOutputSchema, and PromptedOutputSchema all inherit from TextOutputSchema allow_text_output = isinstance(output_schema, _output.TextOutputSchema) + function_tools: list[ToolDefinition] = [] + output_tools: list[ToolDefinition] = [] + for tool_def in ctx.deps.toolset.tool_defs: + if tool_def.kind == 'output': + output_tools.append(tool_def) + else: + function_tools.append(tool_def) + return models.ModelRequestParameters( - function_tools=function_tool_defs, + function_tools=function_tools, output_mode=output_schema.mode, output_tools=output_tools, output_object=output_object, @@ -342,6 +313,7 @@ async def stream( ctx.deps.output_validators, build_run_context(ctx), ctx.deps.usage_limits, + ctx.deps.toolset, ) yield agent_stream # In case the user didn't manually consume the full stream, ensure it is fully consumed here, @@ -437,7 +409,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]): _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field( default=None, repr=False ) - _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False) async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] @@ -519,46 +490,30 @@ async def _handle_tool_calls( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], tool_calls: list[_messages.ToolCallPart], ) -> AsyncIterator[_messages.HandleResponseEvent]: - output_schema = ctx.deps.output_schema run_context = build_run_context(ctx) - final_result: result.FinalResult[NodeRunEndT] | None = None - parts: list[_messages.ModelRequestPart] = [] - - # first, look for the output tool call - if isinstance(output_schema, _output.ToolOutputSchema): - for call, output_tool in output_schema.find_tool(tool_calls): - try: - result_data = await output_tool.process(call, run_context) - result_data = await _validate_output(result_data, ctx, call) - except _output.ToolRetryError as e: - # TODO: Should only increment retry stuff once per node execution, not for each tool call - # Also, should increment the tool-specific retry count rather than the run retry count - ctx.state.increment_retries(ctx.deps.max_result_retries, e) - parts.append(e.tool_retry) - else: - final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) - break + output_parts: list[_messages.ModelRequestPart] = [] + output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1) - # Then build the other request parts based on end strategy - tool_responses: list[_messages.ModelRequestPart] = self._tool_responses async for event in process_function_tools( - tool_calls, - final_result and final_result.tool_name, - final_result and final_result.tool_call_id, - ctx, - tool_responses, + ctx.deps.toolset, tool_calls, None, ctx, output_parts, output_final_result ): yield event - if final_result: - self._next_node = self._handle_final_result(ctx, final_result, tool_responses) + if output_final_result: + final_result = output_final_result[0] + self._next_node = self._handle_final_result(ctx, final_result, output_parts) + elif deferred_tool_calls := ctx.deps.toolset.get_deferred_tool_calls(tool_calls): + if not ctx.deps.output_schema.allows_deferred_tool_calls: + raise exceptions.UserError( + 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' + ) + final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None) + self._next_node = self._handle_final_result(ctx, final_result, output_parts) else: - if tool_responses: - parts.extend(tool_responses) instructions = await ctx.deps.get_instructions(run_context) self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( - _messages.ModelRequest(parts=parts, instructions=instructions) + _messages.ModelRequest(parts=output_parts, instructions=instructions) ) def _handle_final_result( @@ -591,10 +546,10 @@ async def _handle_text_response( m = _messages.RetryPromptPart( content='Plain text responses are not permitted, please include your response in a tool call', ) - raise _output.ToolRetryError(m) + raise ToolRetryError(m) result_data = await _validate_output(result_data, ctx, None) - except _output.ToolRetryError as e: + except ToolRetryError as e: ctx.state.increment_retries(ctx.deps.max_result_retries, e) return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) else: @@ -607,6 +562,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT deps=ctx.deps.user_deps, model=ctx.deps.model, usage=ctx.state.usage, + sampling_model=ctx.deps.sampling_model, prompt=ctx.deps.prompt, messages=ctx.state.message_history, run_step=ctx.state.run_step, @@ -620,258 +576,264 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str: return hashlib.sha1(identifier).hexdigest()[:6] -async def process_function_tools( # noqa C901 +async def process_function_tools( # noqa: C901 + toolset: AbstractToolset[DepsT], tool_calls: list[_messages.ToolCallPart], - output_tool_name: str | None, - output_tool_call_id: str | None, + final_result: result.FinalResult[NodeRunEndT] | None, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], output_parts: list[_messages.ModelRequestPart], + output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1), ) -> AsyncIterator[_messages.HandleResponseEvent]: """Process function (i.e., non-result) tool calls in parallel. Also add stub return parts for any other tools that need it. - Because async iterators can't have return values, we use `output_parts` as an output argument. + Because async iterators can't have return values, we use `output_parts` and `output_final_result` as output arguments. """ - stub_function_tools = bool(output_tool_name) and ctx.deps.end_strategy == 'early' - output_schema = ctx.deps.output_schema - - # we rely on the fact that if we found a result, it's the first output tool in the last - found_used_output_tool = False run_context = build_run_context(ctx) - calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = [] + tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list) for call in tool_calls: - if ( - call.tool_name == output_tool_name - and call.tool_call_id == output_tool_call_id - and not found_used_output_tool - ): - found_used_output_tool = True - output_parts.append( - _messages.ToolReturnPart( + tool_def = toolset.get_tool_def(call.tool_name) + kind = tool_def.kind if tool_def else 'unknown' + tool_calls_by_kind[kind].append(call) + + # First, we handle output tool calls + for call in tool_calls_by_kind['output']: + if final_result: + if final_result.tool_call_id == call.tool_call_id: + part = _messages.ToolReturnPart( tool_name=call.tool_name, content='Final result processed.', tool_call_id=call.tool_call_id, ) - ) - elif tool := ctx.deps.function_tools.get(call.tool_name): - if stub_function_tools: - output_parts.append( - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - ) else: - event = _messages.FunctionToolCallEvent(call) - yield event - calls_to_run.append((tool, call)) - elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx): - if stub_function_tools: - # TODO(Marcelo): We should add coverage for this part of the code. - output_parts.append( # pragma: no cover - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - ) - else: - event = _messages.FunctionToolCallEvent(call) - yield event - calls_to_run.append((mcp_tool, call)) - elif call.tool_name in output_schema.tools: - # if tool_name is in output_schema, it means we found a output tool but an error occurred in - # validation, we don't add another part here - if output_tool_name is not None: yield _messages.FunctionToolCallEvent(call) - if found_used_output_tool: - content = 'Output tool not used - a final result was already processed.' - else: - # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part - content = 'Output tool not used - result failed validation.' part = _messages.ToolReturnPart( tool_name=call.tool_name, - content=content, + content='Output tool not used - a final result was already processed.', tool_call_id=call.tool_call_id, ) yield _messages.FunctionToolResultEvent(part) - output_parts.append(part) - else: - yield _messages.FunctionToolCallEvent(call) - part = _unknown_tool(call.tool_name, call.tool_call_id, ctx) - yield _messages.FunctionToolResultEvent(part) output_parts.append(part) + else: + try: + result_data = await _call_tool(toolset, call, run_context) + except exceptions.UnexpectedModelBehavior as e: + ctx.state.increment_retries(ctx.deps.max_result_retries, e) + raise e + except ToolRetryError as e: + ctx.state.increment_retries(ctx.deps.max_result_retries, e) + yield _messages.FunctionToolCallEvent(call) + output_parts.append(e.tool_retry) + yield _messages.FunctionToolResultEvent(e.tool_retry) + else: + part = _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Final result processed.', + tool_call_id=call.tool_call_id, + ) + output_parts.append(part) + final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) - if not calls_to_run: - return + # Then, we handle function tool calls + calls_to_run: list[_messages.ToolCallPart] = [] + if final_result and ctx.deps.end_strategy == 'early': + for call in tool_calls_by_kind['function']: + output_parts.append( + _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Tool not executed - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + ) + else: + calls_to_run.extend(tool_calls_by_kind['function']) - user_parts: list[_messages.UserPromptPart] = [] + # Then, we handle unknown tool calls + if tool_calls_by_kind['unknown']: + ctx.state.increment_retries(ctx.deps.max_result_retries) + calls_to_run.extend(tool_calls_by_kind['unknown']) - include_content = ( - ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content - ) + for call in calls_to_run: + yield _messages.FunctionToolCallEvent(call) - # Run all tool tasks in parallel - results_by_index: dict[int, _messages.ModelRequestPart] = {} - with ctx.deps.tracer.start_as_current_span( - 'running tools', - attributes={ - 'tools': [call.tool_name for _, call in calls_to_run], - 'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}', - }, - ): - tasks = [ - asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer, include_content), name=call.tool_name) - for tool, call in calls_to_run - ] - - pending = tasks - while pending: - done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) - for task in done: - index = tasks.index(task) - result = task.result() - yield _messages.FunctionToolResultEvent(result) - - if isinstance(result, _messages.RetryPromptPart): - results_by_index[index] = result - elif isinstance(result, _messages.ToolReturnPart): - if isinstance(result.content, _messages.ToolReturn): - tool_return = result.content - if ( - isinstance(tool_return.return_value, _messages.MultiModalContentTypes) - or isinstance(tool_return.return_value, list) - and any( - isinstance(content, _messages.MultiModalContentTypes) - for content in tool_return.return_value # type: ignore - ) - ): - raise exceptions.UserError( - f"{result.tool_name}'s `return_value` contains invalid nested MultiModalContentTypes objects. " - f'Please use `content` instead.' - ) - result.content = tool_return.return_value # type: ignore - result.metadata = tool_return.metadata - if tool_return.content: - user_parts.append( - _messages.UserPromptPart( - content=list(tool_return.content), - timestamp=result.timestamp, - part_kind='user-prompt', - ) - ) - contents: list[Any] - single_content: bool - if isinstance(result.content, list): - contents = result.content # type: ignore - single_content = False - else: - contents = [result.content] - single_content = True - - processed_contents: list[Any] = [] - for content in contents: - if isinstance(content, _messages.ToolReturn): - raise exceptions.UserError( - f"{result.tool_name}'s return contains invalid nested ToolReturn objects. " - f'ToolReturn should be used directly.' - ) - elif isinstance(content, _messages.MultiModalContentTypes): - # Handle direct multimodal content - if isinstance(content, _messages.BinaryContent): - identifier = multi_modal_content_identifier(content.data) - else: - identifier = multi_modal_content_identifier(content.url) - - user_parts.append( - _messages.UserPromptPart( - content=[f'This is file {identifier}:', content], - timestamp=result.timestamp, - part_kind='user-prompt', - ) - ) - processed_contents.append(f'See file {identifier}') - else: - # Handle regular content - processed_contents.append(content) - - if single_content: - result.content = processed_contents[0] - else: - result.content = processed_contents + user_parts: list[_messages.UserPromptPart] = [] - results_by_index[index] = result - else: - assert_never(result) + if calls_to_run: + include_content = ( + ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content + ) - # We append the results at the end, rather than as they are received, to retain a consistent ordering - # This is mostly just to simplify testing - for k in sorted(results_by_index): - output_parts.append(results_by_index[k]) + # Run all tool tasks in parallel + parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {} + with ctx.deps.tracer.start_as_current_span( + 'running tools', + attributes={ + 'tools': [call.tool_name for call in calls_to_run], + 'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}', + }, + ): + tasks = [ + asyncio.create_task( + _call_function_tool(toolset, call, run_context, ctx.deps.tracer, include_content), + name=call.tool_name, + ) + for call in calls_to_run + ] + + pending = tasks + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + index = tasks.index(task) + tool_result_part, extra_parts = task.result() + yield _messages.FunctionToolResultEvent(tool_result_part) + + parts_by_index[index] = [tool_result_part, *extra_parts] + + # We append the results at the end, rather than as they are received, to retain a consistent ordering + # This is mostly just to simplify testing + for k in sorted(parts_by_index): + output_parts.extend(parts_by_index[k]) + + # Finally, we handle deferred tool calls + for call in tool_calls_by_kind['deferred']: + if final_result: + output_parts.append( + _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Tool not executed - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + ) + else: + yield _messages.FunctionToolCallEvent(call) output_parts.extend(user_parts) + if final_result: + output_final_result.append(final_result) -async def _tool_from_mcp_server( - tool_name: str, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], -) -> Tool[DepsT] | None: - """Call each MCP server to find the tool with the given name. - Args: - tool_name: The name of the tool to find. - ctx: The current run context. +async def _call_function_tool( + toolset: AbstractToolset[DepsT], + tool_call: _messages.ToolCallPart, + run_context: RunContext[DepsT], + tracer: Tracer, + include_content: bool = False, +) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]: + """Run the tool function asynchronously. - Returns: - The tool with the given name, or `None` if no tool with the given name is found. + See . """ + span_attributes = { + 'gen_ai.tool.name': tool_call.tool_name, + # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai + 'gen_ai.tool.call.id': tool_call.tool_call_id, + **({'tool_arguments': tool_call.args_as_json_str()} if include_content else {}), + 'logfire.msg': f'running tool: {tool_call.tool_name}', + # add the JSON schema so these attributes are formatted nicely in Logfire + 'logfire.json_schema': json.dumps( + { + 'type': 'object', + 'properties': { + **( + { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + } + if include_content + else {} + ), + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ), + } + + with tracer.start_as_current_span('running tool', attributes=span_attributes) as span: + try: + tool_result = await _call_tool(toolset, tool_call, run_context) + except ToolRetryError as e: + part = e.tool_retry + if include_content and span.is_recording(): + span.set_attribute('tool_response', part.model_response()) + return (e.tool_retry, []) + + part = _messages.ToolReturnPart( + tool_name=tool_call.tool_name, + content=tool_result, + tool_call_id=tool_call.tool_call_id, + ) - async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any: - # There's no normal situation where the server will not be running at this point, we check just in case - # some weird edge case occurs. - if not server.is_running: # pragma: no cover - raise exceptions.UserError(f'MCP server is not running: {server}') - - if server.process_tool_call is not None: - result = await server.process_tool_call(ctx, server.call_tool, tool_name, args) - else: - result = await server.call_tool(tool_name, args) + if include_content and span.is_recording(): + span.set_attribute('tool_response', part.model_response_str()) - return result + extra_parts: list[_messages.ModelRequestPart] = [] - for server in ctx.deps.mcp_servers: - tools = await server.list_tools() - if tool_name in {tool.name for tool in tools}: # pragma: no branch - return Tool(name=tool_name, function=run_tool, takes_ctx=True, max_retries=ctx.deps.default_retries) - return None + def process_content(content: Any) -> Any: + if isinstance(content, _messages.ToolReturn): + raise exceptions.UserError( + f"{tool_call.tool_name}'s return contains invalid nested ToolReturn objects. " + f'ToolReturn should be used directly.' + ) + elif isinstance(content, _messages.MultiModalContentTypes): + if isinstance(content, _messages.BinaryContent): + identifier = multi_modal_content_identifier(content.data) + else: + identifier = multi_modal_content_identifier(content.url) + extra_parts.append( + _messages.UserPromptPart( + content=[f'This is file {identifier}:', content], + part_kind='user-prompt', + ) + ) + return f'See file {identifier}' + else: + return content + + if isinstance(tool_result, _messages.ToolReturn): + if ( + isinstance(tool_result.return_value, _messages.MultiModalContentTypes) + or isinstance(tool_result.return_value, list) + and any( + isinstance(content, _messages.MultiModalContentTypes) + for content in tool_result.return_value # type: ignore + ) + ): + raise exceptions.UserError( + f"{tool_call.tool_name}'s `return_value` contains invalid nested MultiModalContentTypes objects. " + f'Please use `content` instead.' + ) -def _unknown_tool( - tool_name: str, - tool_call_id: str, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], -) -> _messages.RetryPromptPart: - ctx.state.increment_retries(ctx.deps.max_result_retries) - tool_names = list(ctx.deps.function_tools.keys()) + part.content = tool_result.return_value # type: ignore + part.metadata = tool_result.metadata + if tool_result.content: + extra_parts.append( + _messages.UserPromptPart( + content=list(tool_result.content), + part_kind='user-prompt', + ) + ) + elif isinstance(tool_result, list): + contents = cast(list[Any], tool_result) + part.content = [process_content(content) for content in contents] + else: + part.content = process_content(tool_result) - output_schema = ctx.deps.output_schema - if isinstance(output_schema, _output.ToolOutputSchema): - tool_names.extend(output_schema.tool_names()) + return (part, extra_parts) - if tool_names: - msg = f'Available tools: {", ".join(tool_names)}' - else: - msg = 'No tools available.' - return _messages.RetryPromptPart( - tool_name=tool_name, - tool_call_id=tool_call_id, - content=f'Unknown tool name: {tool_name!r}. {msg}', - ) +async def _call_tool( + toolset: AbstractToolset[DepsT], tool_call: _messages.ToolCallPart, run_context: RunContext[DepsT] +) -> Any: + run_context = dataclasses.replace(run_context, tool_call_id=tool_call.tool_call_id) + args_dict = toolset.validate_tool_args(run_context, tool_call.tool_name, tool_call.args) + return await toolset.call_tool(run_context, tool_call.tool_name, args_dict) async def _validate_output( diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 17f686f4b..ab505e3a4 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -3,8 +3,8 @@ import inspect import json from abc import ABC, abstractmethod -from collections.abc import Awaitable, Iterable, Iterator, Sequence -from dataclasses import dataclass, field +from collections.abc import Awaitable, Sequence +from dataclasses import dataclass, field, replace from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload from pydantic import TypeAdapter, ValidationError @@ -13,8 +13,9 @@ from . import _function_schema, _utils, messages as _messages from ._run_context import AgentDepsT, RunContext -from .exceptions import ModelRetry, UserError +from .exceptions import ModelRetry, ToolRetryError, UserError from .output import ( + DeferredToolCalls, NativeOutput, OutputDataT, OutputMode, @@ -27,6 +28,8 @@ ToolOutput, ) from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition +from .toolsets import AbstractToolset +from .toolsets._run import RunToolset if TYPE_CHECKING: from .profiles import ModelProfile @@ -66,14 +69,6 @@ DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation' -class ToolRetryError(Exception): - """Exception used to signal a `ToolRetry` message should be returned to the LLM.""" - - def __init__(self, tool_retry: _messages.RetryPromptPart): - self.tool_retry = tool_retry - super().__init__() - - @dataclass class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]): function: OutputValidatorFunc[AgentDepsT, OutputDataT_inv] @@ -89,6 +84,7 @@ async def validate( result: T, tool_call: _messages.ToolCallPart | None, run_context: RunContext[AgentDepsT], + wrap_validation_errors: bool = True, ) -> T: """Validate a result but calling the function. @@ -96,12 +92,17 @@ async def validate( result: The result data after Pydantic validation the message content. tool_call: The original tool call message, `None` if there was no tool call. run_context: The current run context. + wrap_validation_errors: If true, wrap the validation errors in a retry message. Returns: Result of either the validated result data (ok) or a retry message (Err). """ if self._takes_ctx: - ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None) + ctx = ( + replace(run_context, tool_name=tool_call.tool_name, tool_call_id=tool_call.tool_call_id) + if tool_call + else run_context + ) args = ctx, result else: args = (result,) @@ -114,24 +115,30 @@ async def validate( function = cast(Callable[[Any], T], self.function) result_data = await _utils.run_in_executor(function, *args) except ModelRetry as r: - m = _messages.RetryPromptPart(content=r.message) - if tool_call is not None: - m.tool_name = tool_call.tool_name - m.tool_call_id = tool_call.tool_call_id - raise ToolRetryError(m) from r + if wrap_validation_errors: + m = _messages.RetryPromptPart(content=r.message) + if tool_call is not None: + m.tool_name = tool_call.tool_name + m.tool_call_id = tool_call.tool_call_id + raise ToolRetryError(m) from r + else: + raise r else: return result_data +@dataclass class BaseOutputSchema(ABC, Generic[OutputDataT]): + allows_deferred_tool_calls: bool + @abstractmethod def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: raise NotImplementedError() @property - def tools(self) -> dict[str, OutputTool[OutputDataT]]: - """Get the tools for this output schema.""" - return {} + def toolset(self) -> OutputToolset[Any] | None: + """Get the toolset for this output schema.""" + return None @dataclass(init=False) @@ -163,7 +170,7 @@ def build( ) -> BaseOutputSchema[OutputDataT]: ... @classmethod - def build( + def build( # noqa: C901 cls, output_spec: OutputSpec[OutputDataT], *, @@ -173,79 +180,106 @@ def build( strict: bool | None = None, ) -> BaseOutputSchema[OutputDataT]: """Build an OutputSchema dataclass from an output type.""" - if output_spec is str: - return PlainTextOutputSchema() + raw_outputs = _flatten_output_spec(output_spec) + + outputs = [output for output in raw_outputs if output is not DeferredToolCalls] + allows_deferred_tool_calls = len(outputs) < len(raw_outputs) + if len(outputs) == 0 and allows_deferred_tool_calls: + raise UserError('At least one output type must be provided other than `DeferredToolCalls`.') + + if output := next((output for output in outputs if isinstance(output, NativeOutput)), None): + if len(outputs) > 1: + raise UserError('`NativeOutput` must be the only output type.') - if isinstance(output_spec, NativeOutput): return NativeOutputSchema( - cls._build_processor( - _flatten_output_spec(output_spec.outputs), - name=output_spec.name, - description=output_spec.description, - strict=output_spec.strict, - ) + processor=cls._build_processor( + _flatten_output_spec(output.outputs), + name=output.name, + description=output.description, + strict=output.strict, + ), + allows_deferred_tool_calls=allows_deferred_tool_calls, ) - elif isinstance(output_spec, PromptedOutput): + elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None): + if len(outputs) > 1: + raise UserError('`PromptedOutput` must be the only output type.') + return PromptedOutputSchema( - cls._build_processor( - _flatten_output_spec(output_spec.outputs), - name=output_spec.name, - description=output_spec.description, + processor=cls._build_processor( + _flatten_output_spec(output.outputs), + name=output.name, + description=output.description, ), - template=output_spec.template, + template=output.template, + allows_deferred_tool_calls=allows_deferred_tool_calls, ) text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = [] tool_outputs: Sequence[ToolOutput[OutputDataT]] = [] other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = [] - for output in _flatten_output_spec(output_spec): + for output in outputs: if output is str: text_outputs.append(cast(type[str], output)) elif isinstance(output, TextOutput): text_outputs.append(output) elif isinstance(output, ToolOutput): tool_outputs.append(output) + elif isinstance(output, NativeOutput): + # We can never get here because this is checked for above. + raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover + elif isinstance(output, PromptedOutput): + # We can never get here because this is checked for above. + raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover else: other_outputs.append(output) - tools = cls._build_tools(tool_outputs + other_outputs, name=name, description=description, strict=strict) + toolset = cls._build_toolset(tool_outputs + other_outputs, name=name, description=description, strict=strict) if len(text_outputs) > 0: if len(text_outputs) > 1: - raise UserError('Only one text output is allowed.') + raise UserError('Only one `str` or `TextOutput` is allowed.') text_output = text_outputs[0] text_output_schema = None if isinstance(text_output, TextOutput): text_output_schema = PlainTextOutputProcessor(text_output.output_function) - if len(tools) == 0: - return PlainTextOutputSchema(text_output_schema) + if toolset: + return ToolOrTextOutputSchema( + processor=text_output_schema, toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls + ) else: - return ToolOrTextOutputSchema(processor=text_output_schema, tools=tools) + return PlainTextOutputSchema( + processor=text_output_schema, allows_deferred_tool_calls=allows_deferred_tool_calls + ) if len(tool_outputs) > 0: - return ToolOutputSchema(tools) + return ToolOutputSchema(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls) if len(other_outputs) > 0: schema = OutputSchemaWithoutMode( processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict), - tools=tools, + toolset=toolset, + allows_deferred_tool_calls=allows_deferred_tool_calls, ) if default_mode: schema = schema.with_default_mode(default_mode) return schema - raise UserError('No output type provided.') # pragma: no cover + raise UserError('At least one output type must be provided.') @staticmethod - def _build_tools( + def _build_toolset( outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]], name: str | None = None, description: str | None = None, strict: bool | None = None, - ) -> dict[str, OutputTool[OutputDataT]]: - tools: dict[str, OutputTool[OutputDataT]] = {} + ) -> OutputToolset[Any] | None: + if len(outputs) == 0: + return None + + processors: dict[str, ObjectOutputProcessor[Any]] = {} + tool_defs: list[ToolDefinition] = [] default_name = name or DEFAULT_OUTPUT_TOOL_NAME default_description = description @@ -271,7 +305,7 @@ def _build_tools( i = 1 original_name = name - while name in tools: + while name in processors: i += 1 name = f'{original_name}_{i}' @@ -280,9 +314,26 @@ def _build_tools( strict = default_strict processor = ObjectOutputProcessor(output=output, description=description, strict=strict) - tools[name] = OutputTool(name=name, processor=processor, multiple=multiple) + object_def = processor.object_def + + description = object_def.description + if not description: + description = DEFAULT_OUTPUT_TOOL_DESCRIPTION + if multiple: + description = f'{object_def.name}: {description}' + + tool_def = ToolDefinition( + name=name, + description=description, + parameters_json_schema=object_def.json_schema, + strict=object_def.strict, + outer_typed_dict_key=processor.outer_typed_dict_key, + kind='output', + ) + processors[name] = processor + tool_defs.append(tool_def) - return tools + return OutputToolset(processors=processors, tool_defs=tool_defs) @staticmethod def _build_processor( @@ -314,32 +365,39 @@ def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDa @dataclass(init=False) class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]): processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] - _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) + _toolset: OutputToolset[Any] | None = None def __init__( self, processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT], - tools: dict[str, OutputTool[OutputDataT]], + toolset: OutputToolset[Any] | None, + allows_deferred_tool_calls: bool, ): + super().__init__(allows_deferred_tool_calls) self.processor = processor - self._tools = tools + self._toolset = toolset def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: if mode == 'native': - return NativeOutputSchema(self.processor) + return NativeOutputSchema( + processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls + ) elif mode == 'prompted': - return PromptedOutputSchema(self.processor) + return PromptedOutputSchema( + processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls + ) elif mode == 'tool': - return ToolOutputSchema(self.tools) + return ToolOutputSchema(toolset=self.toolset, allows_deferred_tool_calls=self.allows_deferred_tool_calls) else: assert_never(mode) @property - def tools(self) -> dict[str, OutputTool[OutputDataT]]: - """Get the tools for this output schema.""" - # We return tools here as they're checked in Agent._register_tool. - # At that point we may don't know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time. - return self._tools + def toolset(self) -> OutputToolset[Any] | None: + """Get the toolset for this output schema.""" + # We return a toolset here as they're checked for name conflicts with other toolsets in the Agent constructor. + # At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time, + # but we cover ourselves just in case we end up using the tool output mode. + return self._toolset class TextOutputSchema(OutputSchema[OutputDataT], ABC): @@ -410,7 +468,7 @@ def mode(self) -> OutputMode: def raise_if_unsupported(self, profile: ModelProfile) -> None: """Raise an error if the mode is not supported by the model.""" if not profile.supports_json_schema_output: - raise UserError('Structured output is not supported by the model.') + raise UserError('Native structured output is not supported by the model.') async def process( self, @@ -490,10 +548,11 @@ async def process( @dataclass(init=False) class ToolOutputSchema(OutputSchema[OutputDataT]): - _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) + _toolset: OutputToolset[Any] | None = None - def __init__(self, tools: dict[str, OutputTool[OutputDataT]]): - self._tools = tools + def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tool_calls: bool): + super().__init__(allows_deferred_tool_calls) + self._toolset = toolset @property def mode(self) -> OutputMode: @@ -505,36 +564,9 @@ def raise_if_unsupported(self, profile: ModelProfile) -> None: raise UserError('Output tools are not supported by the model.') @property - def tools(self) -> dict[str, OutputTool[OutputDataT]]: - """Get the tools for this output schema.""" - return self._tools - - def tool_names(self) -> list[str]: - """Return the names of the tools.""" - return list(self.tools.keys()) - - def tool_defs(self) -> list[ToolDefinition]: - """Get tool definitions to register with the model.""" - return [t.tool_def for t in self.tools.values()] - - def find_named_tool( - self, parts: Iterable[_messages.ModelResponsePart], tool_name: str - ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None: - """Find a tool that matches one of the calls, with a specific name.""" - for part in parts: # pragma: no branch - if isinstance(part, _messages.ToolCallPart): # pragma: no branch - if part.tool_name == tool_name: - return part, self.tools[tool_name] - - def find_tool( - self, - parts: Iterable[_messages.ModelResponsePart], - ) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]: - """Find a tool that matches one of the calls.""" - for part in parts: - if isinstance(part, _messages.ToolCallPart): # pragma: no branch - if result := self.tools.get(part.tool_name): - yield part, result + def toolset(self) -> OutputToolset[Any] | None: + """Get the toolset for this output schema.""" + return self._toolset @dataclass(init=False) @@ -542,10 +574,11 @@ class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchem def __init__( self, processor: PlainTextOutputProcessor[OutputDataT] | None, - tools: dict[str, OutputTool[OutputDataT]], + toolset: OutputToolset[Any] | None, + allows_deferred_tool_calls: bool, ): + super().__init__(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls) self.processor = processor - self._tools = tools @property def mode(self) -> OutputMode: @@ -578,7 +611,7 @@ async def process( class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]): object_def: OutputObjectDefinition outer_typed_dict_key: str | None = None - _validator: SchemaValidator + validator: SchemaValidator _function_schema: _function_schema.FunctionSchema | None = None def __init__( @@ -591,7 +624,7 @@ def __init__( ): if inspect.isfunction(output) or inspect.ismethod(output): self._function_schema = _function_schema.function_schema(output, GenerateToolJsonSchema) - self._validator = self._function_schema.validator + self.validator = self._function_schema.validator json_schema = self._function_schema.json_schema json_schema['description'] = self._function_schema.description else: @@ -607,7 +640,7 @@ def __init__( type_adapter = TypeAdapter(response_data_typed_dict) # Really a PluggableSchemaValidator, but it's API-compatible - self._validator = cast(SchemaValidator, type_adapter.validator) + self.validator = cast(SchemaValidator, type_adapter.validator) json_schema = _utils.check_object_json_schema( type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) ) @@ -648,11 +681,7 @@ async def process( Either the validated output data (left) or a retry message (right). """ try: - pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' - if isinstance(data, str): - output = self._validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) - else: - output = self._validator.validate_python(data or {}, allow_partial=pyd_allow_partial) + output = self.validate(data, allow_partial) except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -662,20 +691,40 @@ async def process( else: raise # pragma: lax no cover + try: + output = await self.call(output, run_context) + except ModelRetry as r: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=r.message, + ) + raise ToolRetryError(m) from r + else: + raise # pragma: lax no cover + + return output + + def validate( + self, + data: str | dict[str, Any] | None, + allow_partial: bool = False, + ) -> dict[str, Any]: + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + if isinstance(data, str): + return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) + else: + return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial) + + async def call( + self, + output: Any, + run_context: RunContext[AgentDepsT], + ): if k := self.outer_typed_dict_key: output = output[k] if self._function_schema: - try: - output = await self._function_schema.call(output, run_context) - except ModelRetry as r: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - content=r.message, - ) - raise ToolRetryError(m) from r - else: - raise # pragma: lax no cover + output = await self._function_schema.call(output, run_context) return output @@ -851,72 +900,46 @@ async def process( @dataclass(init=False) -class OutputTool(Generic[OutputDataT]): - processor: ObjectOutputProcessor[OutputDataT] - tool_def: ToolDefinition +class OutputToolset(AbstractToolset[AgentDepsT]): + """A toolset that contains output tools.""" - def __init__(self, *, name: str, processor: ObjectOutputProcessor[OutputDataT], multiple: bool): - self.processor = processor - object_def = processor.object_def + _tool_defs: list[ToolDefinition] + processors: dict[str, ObjectOutputProcessor[Any]] + max_retries: int = field(default=1) + output_validators: list[OutputValidator[AgentDepsT, Any]] = field(default_factory=list) - description = object_def.description - if not description: - description = DEFAULT_OUTPUT_TOOL_DESCRIPTION - if multiple: - description = f'{object_def.name}: {description}' + def __init__( + self, + tool_defs: list[ToolDefinition], + processors: dict[str, ObjectOutputProcessor[Any]], + max_retries: int = 1, + output_validators: list[OutputValidator[AgentDepsT, Any]] = [], + ): + self.processors = processors + self._tool_defs = tool_defs + self.max_retries = max_retries + self.output_validators = output_validators - self.tool_def = ToolDefinition( - name=name, - description=description, - parameters_json_schema=object_def.json_schema, - strict=object_def.strict, - outer_typed_dict_key=processor.outer_typed_dict_key, - ) + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + return RunToolset(self, ctx) - async def process( - self, - tool_call: _messages.ToolCallPart, - run_context: RunContext[AgentDepsT], - allow_partial: bool = False, - wrap_validation_errors: bool = True, - ) -> OutputDataT: - """Process an output message. + @property + def tool_defs(self) -> list[ToolDefinition]: + return self._tool_defs - Args: - tool_call: The tool call from the LLM to validate. - run_context: The current run context. - allow_partial: If true, allow partial validation. - wrap_validation_errors: If true, wrap the validation errors in a retry message. + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self.processors[name].validator - Returns: - Either the validated output data (left) or a retry message (right). - """ - try: - output = await self.processor.process( - tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False - ) - except ValidationError as e: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - tool_name=tool_call.tool_name, - content=e.errors(include_url=False, include_context=False), - tool_call_id=tool_call.tool_call_id, - ) - raise ToolRetryError(m) from e - else: - raise # pragma: lax no cover - except ModelRetry as r: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - tool_name=tool_call.tool_name, - content=r.message, - tool_call_id=tool_call.tool_call_id, - ) - raise ToolRetryError(m) from r - else: - raise # pragma: lax no cover - else: - return output + def _max_retries_for_tool(self, name: str) -> int: + return self.max_retries + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + output = await self.processors[name].call(tool_args, ctx) + for validator in self.output_validators: + output = await validator.validate(output, None, ctx, wrap_validation_errors=False) + return output def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]: diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index bb7f47420..2eb50742f 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -27,10 +27,14 @@ class RunContext(Generic[AgentDepsT]): """The model used in this run.""" usage: Usage """LLM usage associated with the run.""" - prompt: str | Sequence[_messages.UserContent] | None + sampling_model: Model + """The model used for MCP sampling.""" + prompt: str | Sequence[_messages.UserContent] | None = None """The original user prompt passed to the run.""" messages: list[_messages.ModelMessage] = field(default_factory=list) """Messages exchanged in the conversation so far.""" + retries: dict[str, int] = field(default_factory=dict) + """Number of retries for each tool so far.""" tool_call_id: str | None = None """The ID of the tool call.""" tool_name: str | None = None @@ -40,17 +44,4 @@ class RunContext(Generic[AgentDepsT]): run_step: int = 0 """The current step in the run.""" - def replace_with( - self, - retry: int | None = None, - tool_name: str | None | _utils.Unset = _utils.UNSET, - ) -> RunContext[AgentDepsT]: - # Create a new `RunContext` a new `retry` value and `tool_name`. - kwargs = {} - if retry is not None: - kwargs['retry'] = retry - if tool_name is not _utils.UNSET: # pragma: no branch - kwargs['tool_name'] = tool_name - return dataclasses.replace(self, **kwargs) - __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 53e8416c0..475c7ab03 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -15,7 +15,6 @@ from pydantic.json_schema import GenerateJsonSchema from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated -from pydantic_ai.profiles import ModelProfile from pydantic_graph import End, Graph, GraphRun, GraphRunContext from pydantic_graph._utils import get_event_loop @@ -31,8 +30,10 @@ usage as _usage, ) from ._agent_graph import HistoryProcessor +from ._output import OutputToolset from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model from .output import OutputDataT, OutputSpec +from .profiles import ModelProfile from .result import FinalResult, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( @@ -48,6 +49,10 @@ ToolPrepareFunc, ToolsPrepareFunc, ) +from .toolsets import AbstractToolset +from .toolsets.combined import CombinedToolset +from .toolsets.function import FunctionToolset +from .toolsets.prepared import PreparedToolset # Re-exporting like this improves auto-import behavior in PyCharm capture_run_messages = _agent_graph.capture_run_messages @@ -153,11 +158,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field( repr=False ) + _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False) + _output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False) + _user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False) + _toolset: AbstractToolset[AgentDepsT] = dataclasses.field(repr=False) _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) - _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False) - _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) - _default_retries: int = dataclasses.field(repr=False) + _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) + _sampling_model: models.Model | models.KnownModelName | str | None = dataclasses.field(repr=False) + + _running_count: int = dataclasses.field(repr=False) + _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) @overload def __init__( @@ -177,16 +188,18 @@ def __init__( output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[MCPServer] = (), + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> None: ... @overload @deprecated( - '`result_type`, `result_tool_name`, `result_tool_description` & `result_retries` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' + '`result_type`, `result_tool_name` & `result_tool_description` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' ) def __init__( self, @@ -207,11 +220,43 @@ def __init__( result_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + defer_model_check: bool = False, + end_strategy: EndStrategy = 'early', + instrument: InstrumentationSettings | bool | None = None, + history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, + ) -> None: ... + + @overload + @deprecated('`mcp_servers` is deprecated, use `toolsets` instead.') + def __init__( + self, + model: models.Model | models.KnownModelName | str | None = None, + *, + result_type: type[OutputDataT] = str, + instructions: str + | _system_prompt.SystemPromptFunc[AgentDepsT] + | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] + | None = None, + system_prompt: str | Sequence[str] = (), + deps_type: type[AgentDepsT] = NoneType, + name: str | None = None, + model_settings: ModelSettings | None = None, + retries: int = 1, + result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, + result_tool_description: str | None = None, + result_retries: int | None = None, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), + prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> None: ... def __init__( @@ -232,11 +277,13 @@ def __init__( output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[MCPServer] = (), + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Any, ): """Create an agent. @@ -258,14 +305,16 @@ def __init__( when the agent is first run. model_settings: Optional model request settings to use for this agent's runs, by default. retries: The default number of retries to allow before raising an error. - output_retries: The maximum number of retries to allow for result validation, defaults to `retries`. + output_retries: The maximum number of retries to allow for output validation, defaults to `retries`. tools: Tools to register with the agent, you can also register tools via the decorators [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. - prepare_tools: custom method to prepare the tool definition of all tools for each step. + prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools. This is useful if you want to customize the definition of multiple tools or you want to register a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] - mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer] - for each server you want the agent to connect to. + prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step. + This is useful if you want to customize the definition of multiple output tools or you want to register + a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] + toolsets: Toolsets to register with the agent, including MCP servers. defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, which checks for the necessary environment variables. Set this to `false` @@ -283,6 +332,7 @@ def __init__( history_processors: Optional list of callables to process the message history before sending it to the model. Each processor takes a list of messages and returns a modified list of messages. Processors can be sync or async and are applied in sequence. + sampling_model: The model to use for MCP sampling, if not provided, the agent's model will be used. """ if model is None or defer_model_check: self.model = model @@ -325,10 +375,18 @@ def __init__( warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning) output_retries = result_retries + if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None): + warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning) + if toolsets is None: + toolsets = mcp_servers + else: + toolsets = [*toolsets, *mcp_servers] + + _utils.validate_empty_kwargs(_deprecated_kwargs) + default_output_mode = ( self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None ) - _utils.validate_empty_kwargs(_deprecated_kwargs) self._output_schema = _output.OutputSchema[OutputDataT].build( output_type, @@ -353,21 +411,38 @@ def __init__( self._system_prompt_functions = [] self._system_prompt_dynamic_functions = {} - self._function_tools = {} - - self._default_retries = retries self._max_result_retries = output_retries if output_retries is not None else retries - self._mcp_servers = mcp_servers self._prepare_tools = prepare_tools + self._prepare_output_tools = prepare_output_tools + + self._output_toolset = self._output_schema.toolset + if self._output_toolset: + self._output_toolset.max_retries = self._max_result_retries + + self._function_toolset = FunctionToolset(tools, max_retries=retries) + self._user_toolsets = toolsets or () + + all_toolsets: list[AbstractToolset[AgentDepsT]] = [] + if self._output_toolset: + all_toolsets.append(self._output_toolset) + all_toolsets.append(self._function_toolset) + all_toolsets.extend(self._user_toolsets) + + # This will raise errors for any name conflicts + self._toolset = CombinedToolset(all_toolsets) + self.history_processors = history_processors or [] - for tool in tools: - if isinstance(tool, Tool): - self._register_tool(tool) - else: - self._register_tool(Tool(tool)) + + self._sampling_model = sampling_model self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) + self._override_sampling_model: ContextVar[_utils.Option[models.Model]] = ContextVar( + '_override_sampling_model', default=None + ) + + self._exit_stack = None + self._running_count = 0 @staticmethod def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: @@ -387,6 +462,8 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -402,6 +479,8 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @overload @@ -418,6 +497,8 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( @@ -432,6 +513,8 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -462,6 +545,8 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. + sampling_model: Optional model to use for MCP sampling. Returns: The result of the run. @@ -486,6 +571,8 @@ async def main(): model_settings=model_settings, usage_limits=usage_limits, usage=usage, + toolsets=toolsets, + sampling_model=sampling_model, ) as agent_run: async for _ in agent_run: pass @@ -506,6 +593,8 @@ def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @@ -522,6 +611,8 @@ def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @@ -539,6 +630,8 @@ def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... @asynccontextmanager @@ -554,6 +647,8 @@ async def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. @@ -628,6 +723,8 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. + sampling_model: Optional model to use for MCP sampling. Returns: The result of the run. @@ -637,6 +734,8 @@ async def main(): model_used = self._get_model(model) del model + sampling_model_used = self._get_sampling_model(sampling_model) or model_used + if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: raise TypeError('`result_type` and `output_type` cannot be set at the same time.') @@ -651,6 +750,20 @@ async def main(): output_type_ = output_type or self.output_type + # We consider it a user error if a user tries to restrict the result type while having an output validator that + # may change the result type from the restricted type to something else. Therefore, we consider the following + # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. + output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) + + output_toolset = self._output_toolset + if output_schema != self._output_schema or output_validators: + output_toolset = output_schema.toolset + if output_toolset: + output_toolset.max_retries = self._max_result_retries + output_toolset.output_validators = output_validators + if self._prepare_output_tools: + output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) + # Build the graph graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) @@ -665,10 +778,23 @@ async def main(): run_step=0, ) - # We consider it a user error if a user tries to restrict the result type while having an output validator that - # may change the result type from the restricted type to something else. Therefore, we consider the following - # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. - output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) + run_context = RunContext[AgentDepsT]( + deps=deps, + model=model_used, + usage=usage, + sampling_model=sampling_model_used, + prompt=user_prompt, + messages=state.message_history, + run_step=state.run_step, + ) + + user_toolsets = self._user_toolsets if toolsets is None else toolsets + toolset = CombinedToolset([self._function_toolset, *user_toolsets]) + if self._prepare_tools: + toolset = PreparedToolset(toolset, self._prepare_tools) + if output_toolset: + toolset = CombinedToolset([output_toolset, toolset]) + run_toolset = await toolset.prepare_for_run(run_context) model_settings = merge_model_settings(self.model_settings, model_settings) usage_limits = usage_limits or _usage.UsageLimits() @@ -705,10 +831,6 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: return None return '\n\n'.join(parts).strip() - # Copy the function tools so that retry state is agent-run-specific - # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`. - run_function_tools = {k: dataclasses.replace(v) for k, v in self._function_tools.items()} - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( user_deps=deps, prompt=user_prompt, @@ -721,11 +843,9 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: output_schema=output_schema, output_validators=output_validators, history_processors=self.history_processors, - function_tools=run_function_tools, - mcp_servers=self._mcp_servers, - default_retries=self._default_retries, + toolset=run_toolset, + sampling_model=sampling_model_used, tracer=tracer, - prepare_tools=self._prepare_tools, get_instructions=get_instructions, instrumentation_settings=instrumentation_settings, ) @@ -795,6 +915,8 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -810,6 +932,8 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @overload @@ -826,6 +950,8 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( @@ -840,6 +966,8 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -869,6 +997,8 @@ def run_sync( usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. + sampling_model: Optional model to use for MCP sampling. Returns: The result of the run. @@ -895,6 +1025,8 @@ def run_sync( usage_limits=usage_limits, usage=usage, infer_name=False, + toolsets=toolsets, + sampling_model=sampling_model, ) ) @@ -910,6 +1042,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... @overload @@ -925,6 +1058,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @overload @@ -941,6 +1075,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -956,6 +1091,8 @@ async def run_stream( # noqa C901 usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -983,6 +1120,8 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. + sampling_model: Optional model to use for MCP sampling. Returns: The result of the run. @@ -1013,6 +1152,8 @@ async def main(): usage_limits=usage_limits, usage=usage, infer_name=False, + toolsets=toolsets, + sampling_model=sampling_model, ) as agent_run: first_node = agent_run.next_node # start with the first node assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node @@ -1033,15 +1174,17 @@ async def stream_to_final( output_schema, _output.TextOutputSchema ): return FinalResult(s, None, None) - elif isinstance(new_part, _messages.ToolCallPart) and isinstance( - output_schema, _output.ToolOutputSchema - ): # pragma: no branch - for call, _ in output_schema.find_tool([new_part]): - return FinalResult(s, call.tool_name, call.tool_call_id) + elif isinstance(new_part, _messages.ToolCallPart) and ( + tool_def := graph_ctx.deps.toolset.get_tool_def(new_part.tool_name) + ): + if tool_def.kind == 'output': + return FinalResult(s, new_part.tool_name, new_part.tool_call_id) + elif tool_def.kind == 'deferred': + return FinalResult(s, None, None) return None - final_result_details = await stream_to_final(streamed_response) - if final_result_details is not None: + final_result = await stream_to_final(streamed_response) + if final_result is not None: if yielded: raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover yielded = True @@ -1062,17 +1205,13 @@ async def on_complete() -> None: parts: list[_messages.ModelRequestPart] = [] async for _event in _agent_graph.process_function_tools( + graph_ctx.deps.toolset, tool_calls, - final_result_details.tool_name, - final_result_details.tool_call_id, + final_result, graph_ctx, parts, ): pass - # TODO: Should we do something here related to the retry count? - # Maybe we should move the incrementing of the retry count to where we actually make a request? - # if any(isinstance(part, _messages.RetryPromptPart) for part in parts): - # ctx.state.increment_retries(ctx.deps.max_result_retries) if parts: messages.append(_messages.ModelRequest(parts)) @@ -1084,8 +1223,9 @@ async def on_complete() -> None: graph_ctx.deps.output_schema, _agent_graph.build_run_context(graph_ctx), graph_ctx.deps.output_validators, - final_result_details.tool_name, + final_result.tool_name, on_complete, + graph_ctx.deps.toolset, ) break next_node = await agent_run.next(node) @@ -1104,6 +1244,7 @@ def override( *, deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, + sampling_model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent dependencies and model. @@ -1113,6 +1254,7 @@ def override( Args: deps: The dependencies to use instead of the dependencies passed to the agent run. model: The model to use instead of the model passed to the agent run. + sampling_model: The model to use for MCP sampling instead of the sampling model passed to the agent run. """ if _utils.is_set(deps): deps_token = self._override_deps.set(_utils.Some(deps)) @@ -1124,6 +1266,11 @@ def override( else: model_token = None + if _utils.is_set(sampling_model): + sampling_model_token = self._override_sampling_model.set(_utils.Some(models.infer_model(sampling_model))) + else: + sampling_model_token = None + try: yield finally: @@ -1131,6 +1278,8 @@ def override( self._override_deps.reset(deps_token) if model_token is not None: self._override_model.reset(model_token) + if sampling_model_token is not None: + self._override_sampling_model.reset(sampling_model_token) @overload def instructions( @@ -1418,7 +1567,7 @@ def tool_decorator( func_: ToolFuncContext[AgentDepsT, ToolParams], ) -> ToolFuncContext[AgentDepsT, ToolParams]: # noinspection PyTypeChecker - self._register_function( + self._function_toolset.register_function( func_, True, name, @@ -1434,7 +1583,7 @@ def tool_decorator( return tool_decorator else: # noinspection PyTypeChecker - self._register_function( + self._function_toolset.register_function( func, True, name, @@ -1525,7 +1674,7 @@ async def spam(ctx: RunContext[str]) -> float: def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]: # noinspection PyTypeChecker - self._register_function( + self._function_toolset.register_function( func_, False, name, @@ -1540,7 +1689,7 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams return tool_decorator else: - self._register_function( + self._function_toolset.register_function( func, False, name, @@ -1553,47 +1702,6 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams ) return func - def _register_function( - self, - func: ToolFuncEither[AgentDepsT, ToolParams], - takes_ctx: bool, - name: str | None, - retries: int | None, - prepare: ToolPrepareFunc[AgentDepsT] | None, - docstring_format: DocstringFormat, - require_parameter_descriptions: bool, - schema_generator: type[GenerateJsonSchema], - strict: bool | None, - ) -> None: - """Private utility to register a function as a tool.""" - retries_ = retries if retries is not None else self._default_retries - tool = Tool[AgentDepsT]( - func, - takes_ctx=takes_ctx, - name=name, - max_retries=retries_, - prepare=prepare, - docstring_format=docstring_format, - require_parameter_descriptions=require_parameter_descriptions, - schema_generator=schema_generator, - strict=strict, - ) - self._register_tool(tool) - - def _register_tool(self, tool: Tool[AgentDepsT]) -> None: - """Private utility to register a tool instance.""" - if tool.max_retries is None: - # noinspection PyTypeChecker - tool = dataclasses.replace(tool, max_retries=self._default_retries) - - if tool.name in self._function_tools: - raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}') - - if tool.name in self._output_schema.tools: - raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}') - - self._function_tools[tool.name] = tool - def _get_model(self, model: models.Model | models.KnownModelName | str | None) -> models.Model: """Create a model configured for this agent. @@ -1638,6 +1746,19 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T: else: return deps + def _get_sampling_model( + self, sampling_model: models.Model | models.KnownModelName | str | None + ) -> models.Model | None: + """Get the sampling model for a run.""" + if some_sampling_model := self._override_sampling_model.get(): + return some_sampling_model.value + elif sampling_model is not None: + return models.infer_model(sampling_model) + elif self._sampling_model is not None: + return models.infer_model(self._sampling_model) + else: + return None + def _infer_name(self, function_frame: FrameType | None) -> None: """Infer the agent name from the call frame. @@ -1723,7 +1844,24 @@ def is_end_node( """ return isinstance(node, End) + async def __aenter__(self) -> Self: + if self._running_count == 0: + self._exit_stack = AsyncExitStack() + await self._exit_stack.enter_async_context(self._toolset) + self._running_count += 1 + return self + + async def __aexit__(self, *args: Any) -> bool | None: + self._running_count -= 1 + if self._running_count <= 0 and self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None + return None + @asynccontextmanager + @deprecated( + '`run_mcp_servers` is deprecated, use `async with agent:` instead. If you need to set an MCP sampling model, use `with agent.override(sampling_model=...)`.' + ) async def run_mcp_servers( self, model: models.Model | models.KnownModelName | str | None = None ) -> AsyncIterator[None]: @@ -1731,20 +1869,9 @@ async def run_mcp_servers( Returns: a context manager to start and shutdown the servers. """ - try: - sampling_model: models.Model | None = self._get_model(model) - except exceptions.UserError: # pragma: no cover - sampling_model = None - - exit_stack = AsyncExitStack() - try: - for mcp_server in self._mcp_servers: - if sampling_model is not None: # pragma: no branch - exit_stack.enter_context(mcp_server.override_sampling_model(sampling_model)) - await exit_stack.enter_async_context(mcp_server) - yield - finally: - await exit_stack.aclose() + with self.override(sampling_model=model or _utils.UNSET): + async with self: + yield def to_a2a( self, diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py index 078347825..01f599a9f 100644 --- a/pydantic_ai_slim/pydantic_ai/exceptions.py +++ b/pydantic_ai_slim/pydantic_ai/exceptions.py @@ -2,12 +2,16 @@ import json import sys +from typing import TYPE_CHECKING if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup # pragma: lax no cover else: ExceptionGroup = ExceptionGroup # pragma: lax no cover +if TYPE_CHECKING: + from .messages import RetryPromptPart + __all__ = ( 'ModelRetry', 'UserError', @@ -113,3 +117,11 @@ def __init__(self, status_code: int, model_name: str, body: object | None = None class FallbackExceptionGroup(ExceptionGroup): """A group of exceptions that can be raised when all fallback models fail.""" + + +class ToolRetryError(Exception): + """Exception used to signal a `ToolRetry` message should be returned to the LLM.""" + + def __init__(self, tool_retry: RetryPromptPart): + self.tool_retry = tool_retry + super().__init__() diff --git a/pydantic_ai_slim/pydantic_ai/ext/langchain.py b/pydantic_ai_slim/pydantic_ai/ext/langchain.py index 9d13adda0..83fc6b146 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/langchain.py +++ b/pydantic_ai_slim/pydantic_ai/ext/langchain.py @@ -3,6 +3,7 @@ from pydantic.json_schema import JsonSchemaValue from pydantic_ai.tools import Tool +from pydantic_ai.toolsets.function import FunctionToolset class LangChainTool(Protocol): @@ -23,7 +24,7 @@ def description(self) -> str: ... def run(self, *args: Any, **kwargs: Any) -> str: ... -__all__ = ('tool_from_langchain',) +__all__ = ('tool_from_langchain', 'LangChainToolset') def tool_from_langchain(langchain_tool: LangChainTool) -> Tool: @@ -59,3 +60,10 @@ def proxy(*args: Any, **kwargs: Any) -> str: description=function_description, json_schema=schema, ) + + +class LangChainToolset(FunctionToolset): + """A toolset that wraps LangChain tools.""" + + def __init__(self, tools: list[LangChainTool]): + super().__init__([tool_from_langchain(tool) for tool in tools]) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 2b68832d0..c62137d65 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -3,8 +3,8 @@ import base64 import functools from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence -from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager +from collections.abc import AsyncIterator, Iterator, Sequence +from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager, nullcontext from contextvars import ContextVar from dataclasses import dataclass from pathlib import Path @@ -17,6 +17,14 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from typing_extensions import Self, assert_never, deprecated +from pydantic_ai._run_context import RunContext +from pydantic_ai.tools import ToolDefinition + +from .toolsets import AbstractToolset +from .toolsets._run import RunToolset +from .toolsets.prefixed import PrefixedToolset +from .toolsets.processed import ProcessedToolset, ToolProcessFunc + try: from mcp import types as mcp_types from mcp.client.session import ClientSession, LoggingFnT @@ -33,12 +41,12 @@ ) from _import_error # after mcp imports so any import error maps to this file, not _mcp.py -from . import _mcp, exceptions, messages, models, tools +from . import _mcp, exceptions, messages, models __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP' -class MCPServer(ABC): +class MCPServer(AbstractToolset[Any], ABC): """Base class for attaching agents to MCP servers. See for more information. @@ -49,8 +57,9 @@ class MCPServer(ABC): log_level: mcp_types.LoggingLevel | None = None log_handler: LoggingFnT | None = None timeout: float = 5 - process_tool_call: ProcessToolCallback | None = None + process_tool_call: ToolProcessFunc[Any] | None = None allow_sampling: bool = True + max_retries: int = 1 # } end of "abstract fields" _running_count: int = 0 @@ -91,48 +100,48 @@ async def client_streams( raise NotImplementedError('MCP Server subclasses must implement this method.') yield - def get_prefixed_tool_name(self, tool_name: str) -> str: - """Get the tool name with prefix if `tool_prefix` is set.""" - return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name - - def get_unprefixed_tool_name(self, tool_name: str) -> str: - """Get original tool name without prefix for calling tools.""" - return tool_name.removeprefix(f'{self.tool_prefix}_') if self.tool_prefix else tool_name - @property def is_running(self) -> bool: """Check if the MCP server is running.""" return bool(self._running_count) - async def list_tools(self) -> list[tools.ToolDefinition]: + @property + def name(self) -> str: + return repr(self) + + @property + def tool_name_conflict_hint(self) -> str: + return 'Consider setting `tool_prefix` to avoid name conflicts.' + + async def list_tools(self) -> list[mcp_types.Tool]: """Retrieve tools that are currently active on the server. Note: - We don't cache tools as they might change. - We also don't subscribe to the server to avoid complexity. """ - mcp_tools = await self._client.list_tools() - return [ - tools.ToolDefinition( - name=self.get_prefixed_tool_name(tool.name), - description=tool.description or '', - parameters_json_schema=tool.inputSchema, - ) - for tool in mcp_tools.tools - ] + async with self: # Ensure server is running + result = await self._client.list_tools() + return result.tools async def call_tool( self, - tool_name: str, - arguments: dict[str, Any], + ctx: RunContext[Any], + name: str, + tool_args: dict[str, Any], + *args: Any, metadata: dict[str, Any] | None = None, + **kwargs: Any, ) -> ToolResult: """Call a tool on the server. Args: - tool_name: The name of the tool to call. - arguments: The arguments to pass to the tool. + ctx: The run context of the tool call. + name: The name of the tool to call. + tool_args: The arguments to pass to the tool. + *args: Additional arguments passed by a tool call processor. metadata: Request-level metadata (optional) + **kwargs: Additional keyword arguments passed by a tool call processor. Returns: The result of the tool call. @@ -140,23 +149,28 @@ async def call_tool( Raises: ModelRetry: If the tool call fails. """ - try: - # meta param is not provided by session yet, so build and can send_request directly. - result = await self._client.send_request( - mcp_types.ClientRequest( - mcp_types.CallToolRequest( - method='tools/call', - params=mcp_types.CallToolRequestParams( - name=self.get_unprefixed_tool_name(tool_name), - arguments=arguments, - _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, + sampling_contextmanager = ( + nullcontext() if self._get_sampling_model() else self.override_sampling_model(ctx.sampling_model) + ) + with sampling_contextmanager: + async with self: # Ensure server is running + try: + # meta param is not provided by session yet, so build and can send_request directly. + result = await self._client.send_request( + mcp_types.ClientRequest( + mcp_types.CallToolRequest( + method='tools/call', + params=mcp_types.CallToolRequestParams( + name=name, + arguments=tool_args, + _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, + ), + ) ), + mcp_types.CallToolResult, ) - ), - mcp_types.CallToolResult, - ) - except McpError as e: - raise exceptions.ModelRetry(e.error.message) + except McpError as e: + raise exceptions.ModelRetry(e.error.message) content = [self._map_tool_result_part(part) for part in result.content] @@ -166,6 +180,40 @@ async def call_tool( else: return content[0] if len(content) == 1 else content + async def prepare_for_run(self, ctx: RunContext[Any]) -> RunToolset[Any]: + frozen_toolset = RunToolset(self, ctx, await self.list_tool_defs()) + if self.process_tool_call: + frozen_toolset = await ProcessedToolset(frozen_toolset, self.process_tool_call).prepare_for_run(ctx) + if self.tool_prefix: + frozen_toolset = await PrefixedToolset(frozen_toolset, self.tool_prefix).prepare_for_run(ctx) + return RunToolset(frozen_toolset, ctx, original=self) + + @property + def tool_defs(self) -> list[ToolDefinition]: + # The actual tool definitions are loaded in `prepare_for_run` and cached on the `RunToolset` that will wrap us + return [] + + async def list_tool_defs(self) -> list[ToolDefinition]: + mcp_tools = await self.list_tools() + return [ + ToolDefinition( + name=mcp_tool.name, + description=mcp_tool.description or '', + parameters_json_schema=mcp_tool.inputSchema, + ) + for mcp_tool in mcp_tools + ] + + def _get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> pydantic_core.SchemaValidator: + return pydantic_core.SchemaValidator( + schema=pydantic_core.core_schema.dict_schema( + pydantic_core.core_schema.str_schema(), pydantic_core.core_schema.any_schema() + ) + ) + + def _max_retries_for_tool(self, name: str) -> int: + return self.max_retries + async def __aenter__(self) -> Self: if self._running_count == 0: self._exit_stack = AsyncExitStack() @@ -197,11 +245,14 @@ async def __aexit__( if self._running_count <= 0: await self._exit_stack.aclose() + def _get_sampling_model(self) -> models.Model | None: + return self._override_sampling_model.get() or self.sampling_model + async def _sampling_callback( self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData: """MCP sampling callback.""" - sampling_model = self._override_sampling_model.get() or self.sampling_model + sampling_model = self._get_sampling_model() if sampling_model is None: raise ValueError('Sampling model is not set') # pragma: no cover @@ -289,10 +340,10 @@ class MCPServerStdio(MCPServer): 'stdio', ] ) - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent: # (2)! ... ``` @@ -339,12 +390,15 @@ async def main(): timeout: float = 5 """The timeout in seconds to wait for the client to initialize.""" - process_tool_call: ProcessToolCallback | None = None + process_tool_call: ToolProcessFunc[Any] | None = None """Hook to customize tool calling and optionally pass extra metadata.""" allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" + max_retries: int = 1 + """The maximum number of times to retry a tool call.""" + @asynccontextmanager async def client_streams( self, @@ -434,12 +488,15 @@ class _MCPServerHTTP(MCPServer): If the connection cannot be established within this time, the operation will fail. """ - process_tool_call: ProcessToolCallback | None = None + process_tool_call: ToolProcessFunc[Any] | None = None """Hook to customize tool calling and optionally pass extra metadata.""" allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" + max_retries: int = 1 + """The maximum number of times to retry a tool call.""" + @property @abstractmethod def _transport_client( @@ -521,10 +578,10 @@ class MCPServerSSE(_MCPServerHTTP): from pydantic_ai.mcp import MCPServerSSE server = MCPServerSSE('http://localhost:3001/sse') # (1)! - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent: # (2)! ... ``` @@ -555,10 +612,10 @@ class MCPServerHTTP(MCPServerSSE): from pydantic_ai.mcp import MCPServerHTTP server = MCPServerHTTP('http://localhost:3001/sse') # (1)! - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent: # (2)! ... ``` @@ -584,10 +641,10 @@ class MCPServerStreamableHTTP(_MCPServerHTTP): from pydantic_ai.mcp import MCPServerStreamableHTTP server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent: # (2)! ... ``` """ @@ -604,24 +661,4 @@ def _transport_client(self): | list[Any] | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]] ) -"""The result type of a tool call.""" - -CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[ToolResult]] -"""A function type that represents a tool call.""" - -ProcessToolCallback = Callable[ - [ - tools.RunContext[Any], - CallToolFunc, - str, - dict[str, Any], - ], - Awaitable[ToolResult], -] -"""A process tool callback. - -It accepts a run context, the original tool call function, a tool name, and arguments. - -Allows wrapping an MCP server tool call to customize it, including adding extra request -metadata. -""" +"""The result type of an MCP tool call.""" diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index 246823292..316921781 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -6,7 +6,8 @@ from typing_extensions import TypeAliasType, TypeVar -from .tools import RunContext +from .messages import ToolCallPart +from .tools import RunContext, ToolDefinition __all__ = ( # classes @@ -290,3 +291,11 @@ def split_into_words(text: str) -> list[str]: See [output docs](../output.md) for more information. """ + + +@dataclass +class DeferredToolCalls: + """Container for calls of deferred tools. This can be used as an agent's `output_type` and will be used as the output of the agent run if the model called any deferred tools.""" + + tool_calls: list[ToolCallPart] + tool_defs: dict[str, ToolDefinition] diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index a8b46f029..5f732aa86 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -3,13 +3,15 @@ import warnings from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from copy import copy -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from datetime import datetime -from typing import Generic +from typing import Generic, cast from pydantic import ValidationError from typing_extensions import TypeVar, deprecated, overload +from pydantic_ai.toolsets._run import RunToolset + from . import _utils, exceptions, messages as _messages, models from ._output import ( OutputDataT_inv, @@ -47,6 +49,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]): _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] _usage_limits: UsageLimits | None + _toolset: RunToolset[AgentDepsT] _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) _final_result_event: FinalResultEvent | None = field(default=None, init=False) @@ -90,33 +93,44 @@ async def _validate_response( self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" - call = None if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None: - match = self._output_schema.find_named_tool(message.parts, output_tool_name) - if match is None: + tool_call = next( + ( + part + for part in message.parts + if isinstance(part, _messages.ToolCallPart) and part.tool_name == output_tool_name + ), + None, + ) + if tool_call is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover - f'Invalid response, unable to find tool: {self._output_schema.tool_names()}' + f'Invalid response, unable to find tool call for {output_tool_name!r}' ) - - call, output_tool = match - result_data = await output_tool.process( - call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + run_context = replace(self._run_ctx, tool_call_id=tool_call.tool_call_id) + args_dict = self._toolset.validate_tool_args( + run_context, tool_call.tool_name, tool_call.args, allow_partial=allow_partial ) + return await self._toolset.call_tool(run_context, tool_call.tool_name, args_dict) + elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): + if not self._output_schema.allows_deferred_tool_calls: + raise exceptions.UserError( + 'There are deferred tool calls but DeferredToolCalls is not among output types.' + ) + return cast(OutputDataT, deferred_tool_calls) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) + for validator in self._output_validators: + result_data = await validator.validate(result_data, None, self._run_ctx) + return result_data else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover 'Invalid response, unable to process text output' ) - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) - return result_data - def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s. @@ -134,13 +148,19 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages. """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" if isinstance(e, _messages.PartStartEvent): new_part = e.part - if isinstance(new_part, _messages.ToolCallPart) and isinstance(output_schema, ToolOutputSchema): - for call, _ in output_schema.find_tool([new_part]): # pragma: no branch - return _messages.FinalResultEvent(tool_name=call.tool_name, tool_call_id=call.tool_call_id) - elif isinstance(new_part, _messages.TextPart) and isinstance( + if isinstance(new_part, _messages.TextPart) and isinstance( output_schema, TextOutputSchema ): # pragma: no branch return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) + elif isinstance(new_part, _messages.ToolCallPart) and ( + tool_def := self._toolset.get_tool_def(new_part.tool_name) + ): + if tool_def.kind == 'output': + return _messages.FinalResultEvent( + tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id + ) + elif tool_def.kind == 'deferred': + return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) usage_checking_stream = _get_usage_checking_stream_response( self._raw_stream_response, self._usage_limits, self.usage @@ -175,6 +195,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]): _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _output_tool_name: str | None _on_complete: Callable[[], Awaitable[None]] + _toolset: RunToolset[AgentDepsT] _initial_run_ctx_usage: Usage = field(init=False) is_complete: bool = field(default=False, init=False) @@ -408,33 +429,44 @@ async def validate_structured_output( self, message: _messages.ModelResponse, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" - call = None if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None: - match = self._output_schema.find_named_tool(message.parts, self._output_tool_name) - if match is None: + tool_call = next( + ( + part + for part in message.parts + if isinstance(part, _messages.ToolCallPart) and part.tool_name == self._output_tool_name + ), + None, + ) + if tool_call is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover - f'Invalid response, unable to find tool: {self._output_schema.tool_names()}' + f'Invalid response, unable to find tool call for {self._output_tool_name!r}' ) - - call, output_tool = match - result_data = await output_tool.process( - call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + run_context = replace(self._run_ctx, tool_call_id=tool_call.tool_call_id) + args_dict = self._toolset.validate_tool_args( + run_context, tool_call.tool_name, tool_call.args, allow_partial=allow_partial ) + return await self._toolset.call_tool(run_context, tool_call.tool_name, args_dict) + elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): + if not self._output_schema.allows_deferred_tool_calls: + raise exceptions.UserError( + 'There are deferred tool calls but DeferredToolCalls is not among output types.' + ) + return cast(OutputDataT, deferred_tool_calls) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) + for validator in self._output_validators: + result_data = await validator.validate(result_data, None, self._run_ctx) # pragma: no cover + return result_data else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover 'Invalid response, unable to process text output' ) - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover - return result_data - async def _validate_text_output(self, text: str) -> str: for validator in self._output_validators: text = await validator.validate(text, None, self._run_ctx) # pragma: no cover diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index c22630d72..e79d932a8 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,20 +1,15 @@ from __future__ import annotations as _annotations -import dataclasses -import json from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field from typing import Any, Callable, Generic, Literal, Union -from opentelemetry.trace import Tracer -from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue from pydantic_core import SchemaValidator, core_schema from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, TypeVar -from . import _function_schema, _utils, messages as _messages +from . import _function_schema, _utils from ._run_context import AgentDepsT, RunContext -from .exceptions import ModelRetry, UnexpectedModelBehavior __all__ = ( 'AgentDepsT', @@ -32,7 +27,6 @@ 'ToolDefinition', ) -from .messages import ToolReturnPart ToolParams = ParamSpec('ToolParams', default=...) """Retrieval function param spec.""" @@ -173,12 +167,6 @@ class Tool(Generic[AgentDepsT]): This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request. """ - # TODO: Consider moving this current_retry state to live on something other than the tool. - # We've worked around this for now by copying instances of the tool when creating new runs, - # but this is a bit fragile. Moving the tool retry counts to live on the agent run state would likely clean things - # up, though is also likely a larger effort to refactor. - current_retry: int = field(default=0, init=False) - def __init__( self, function: ToolFuncEither[AgentDepsT], @@ -303,6 +291,15 @@ def from_schema( function_schema=function_schema, ) + @property + def tool_def(self): + return ToolDefinition( + name=self.name, + description=self.description, + parameters_json_schema=self.function_schema.json_schema, + strict=self.strict, + ) + async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. @@ -312,113 +309,11 @@ async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition Returns: return a `ToolDefinition` or `None` if the tools should not be registered for this run. """ - tool_def = ToolDefinition( - name=self.name, - description=self.description, - parameters_json_schema=self.function_schema.json_schema, - strict=self.strict, - ) + base_tool_def = self.tool_def if self.prepare is not None: - return await self.prepare(ctx, tool_def) + return await self.prepare(ctx, base_tool_def) else: - return tool_def - - async def run( - self, - message: _messages.ToolCallPart, - run_context: RunContext[AgentDepsT], - tracer: Tracer, - include_content: bool = False, - ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: - """Run the tool function asynchronously. - - This method wraps `_run` in an OpenTelemetry span. - - See . - """ - span_attributes = { - 'gen_ai.tool.name': self.name, - # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai - 'gen_ai.tool.call.id': message.tool_call_id, - **({'tool_arguments': message.args_as_json_str()} if include_content else {}), - 'logfire.msg': f'running tool: {self.name}', - # add the JSON schema so these attributes are formatted nicely in Logfire - 'logfire.json_schema': json.dumps( - { - 'type': 'object', - 'properties': { - **( - { - 'tool_arguments': {'type': 'object'}, - 'tool_response': {'type': 'object'}, - } - if include_content - else {} - ), - 'gen_ai.tool.name': {}, - 'gen_ai.tool.call.id': {}, - }, - } - ), - } - with tracer.start_as_current_span('running tool', attributes=span_attributes) as span: - response = await self._run(message, run_context) - if include_content and span.is_recording(): - span.set_attribute( - 'tool_response', - response.model_response_str() - if isinstance(response, ToolReturnPart) - else response.model_response(), - ) - - return response - - async def _run( - self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT] - ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: - try: - validator = self.function_schema.validator - if isinstance(message.args, str): - args_dict = validator.validate_json(message.args or '{}') - else: - args_dict = validator.validate_python(message.args or {}) - except ValidationError as e: - return self._on_error(e, message) - - ctx = dataclasses.replace( - run_context, - retry=self.current_retry, - tool_name=message.tool_name, - tool_call_id=message.tool_call_id, - ) - try: - response_content = await self.function_schema.call(args_dict, ctx) - except ModelRetry as e: - return self._on_error(e, message) - - self.current_retry = 0 - return _messages.ToolReturnPart( - tool_name=message.tool_name, - content=response_content, - tool_call_id=message.tool_call_id, - ) - - def _on_error( - self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart - ) -> _messages.RetryPromptPart: - self.current_retry += 1 - if self.max_retries is None or self.current_retry > self.max_retries: - raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc - else: - if isinstance(exc, ValidationError): - content = exc.errors(include_url=False, include_context=False) - else: - content = exc.message - return _messages.RetryPromptPart( - tool_name=call_message.tool_name, - content=content, - tool_call_id=call_message.tool_call_id, - ) + return base_tool_def ObjectJsonSchema: TypeAlias = dict[str, Any] @@ -429,6 +324,9 @@ def _on_error( With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_parts=Any` """ +ToolKind: TypeAlias = Literal['function', 'output', 'deferred'] +"""Kind of tool.""" + @dataclass(repr=False) class ToolDefinition: @@ -464,4 +362,12 @@ class ToolDefinition: Note: this is currently only supported by OpenAI models. """ + kind: ToolKind = field(default='function') + """The kind of tool: + - `'function'`: a tool that can be executed by Pydantic AI and has its result returned to the model + - `'output'`: a tool that passes through an output value that ends the run + - `'deferred'`: a tool that will be executed not by Pydantic AI, but by the upstream service that called the agent, such as a web application that supports frontend-defined tools provided to Pydantic AI via e.g. [AG-UI](https://docs.ag-ui.com/concepts/tools#frontend-defined-tools). + When the model calls a deferred tool, the agent run ends with a `DeferredToolCalls` object and a new run is expected to be started at a later point with the message history and new `ToolReturnPart`s corresponding to each deferred call. + """ + __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py new file mode 100644 index 000000000..66caa1678 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from types import TracebackType +from typing import TYPE_CHECKING, Any, Generic, Literal + +from pydantic_core import SchemaValidator +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition + +if TYPE_CHECKING: + from ._run import RunToolset + + +class AbstractToolset(ABC, Generic[AgentDepsT]): + """A toolset is a collection of tools that can be used by an agent. + + It is responsible for: + - Listing the tools it contains + - Validating the arguments of the tools + - Calling the tools + """ + + @property + def name(self) -> str: + return self.__class__.__name__.replace('Toolset', ' toolset') + + @property + def tool_name_conflict_hint(self) -> str: + return 'Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.' + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> bool | None: + return None + + @abstractmethod + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + raise NotImplementedError() + + @property + @abstractmethod + def tool_defs(self) -> list[ToolDefinition]: + raise NotImplementedError() + + @property + def tool_names(self) -> list[str]: + return [tool_def.name for tool_def in self.tool_defs] + + def get_tool_def(self, name: str) -> ToolDefinition | None: + return next((tool_def for tool_def in self.tool_defs if tool_def.name == name), None) + + @abstractmethod + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + raise NotImplementedError() + + def validate_tool_args( + self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False + ) -> dict[str, Any]: + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + validator = self._get_tool_args_validator(ctx, name) + if isinstance(args, str): + return validator.validate_json(args or '{}', allow_partial=pyd_allow_partial) + else: + return validator.validate_python(args or {}, allow_partial=pyd_allow_partial) + + @abstractmethod + def _max_retries_for_tool(self, name: str) -> int: + raise NotImplementedError() + + @abstractmethod + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + raise NotImplementedError() diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/_individually_prepared.py b/pydantic_ai_slim/pydantic_ai/toolsets/_individually_prepared.py new file mode 100644 index 000000000..88ccbd728 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/_individually_prepared.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from .._run_context import AgentDepsT, RunContext +from ..exceptions import UserError +from ..tools import ( + ToolDefinition, + ToolPrepareFunc, +) +from ._mapped import MappedToolset +from ._run import RunToolset +from .wrapper import WrapperToolset + + +@dataclass +class IndividuallyPreparedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prepares the tools it contains using a per-tool prepare function.""" + + prepare_func: ToolPrepareFunc[AgentDepsT] + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + + tool_defs: dict[str, ToolDefinition] = {} + name_map: dict[str, str] = {} + for original_tool_def in wrapped_for_run.tool_defs: + original_name = original_tool_def.name + tool_def = await self.prepare_func(ctx, original_tool_def) + if not tool_def: + continue + + new_name = tool_def.name + if new_name in tool_defs: + if new_name != original_name: + raise UserError(f'Renaming tool {original_name!r} to {new_name!r} conflicts with existing tool.') + else: + raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.') + name_map[new_name] = original_name + + tool_defs[new_name] = tool_def + + mapped_for_run = await MappedToolset(wrapped_for_run, list(tool_defs.values()), name_map).prepare_for_run(ctx) + return RunToolset(mapped_for_run, ctx, original=self) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/_mapped.py b/pydantic_ai_slim/pydantic_ai/toolsets/_mapped.py new file mode 100644 index 000000000..47a7a9e72 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/_mapped.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from pydantic_core import SchemaValidator + +from .._run_context import AgentDepsT, RunContext +from ..tools import ( + ToolDefinition, +) +from . import AbstractToolset +from ._run import RunToolset +from .wrapper import WrapperToolset + + +@dataclass(init=False) +class MappedToolset(WrapperToolset[AgentDepsT]): + """A toolset that maps renamed tool names to original tool names. Used by `IndividuallyPreparedToolset` as the prepare function may rename a tool.""" + + name_map: dict[str, str] + _tool_defs: list[ToolDefinition] + + def __init__( + self, + wrapped: AbstractToolset[AgentDepsT], + tool_defs: list[ToolDefinition], + name_map: dict[str, str], + ): + super().__init__(wrapped) + self._tool_defs = tool_defs + self.name_map = name_map + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + mapped_for_run = MappedToolset(wrapped_for_run, self._tool_defs, self.name_map) + return RunToolset(mapped_for_run, ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return self._tool_defs + + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return super()._get_tool_args_validator(ctx, self._map_name(name)) + + def _max_retries_for_tool(self, name: str) -> int: + return super()._max_retries_for_tool(self._map_name(name)) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await super().call_tool(ctx, self._map_name(name), tool_args, *args, **kwargs) + + def _map_name(self, name: str) -> str: + return self.name_map.get(name, name) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/_run.py b/pydantic_ai_slim/pydantic_ai/toolsets/_run.py new file mode 100644 index 000000000..411ee35c9 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/_run.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from collections.abc import Iterable, Iterator +from contextlib import contextmanager +from dataclasses import dataclass, replace +from typing import Any + +from pydantic import ValidationError + +from pydantic_ai.output import DeferredToolCalls + +from .. import messages as _messages +from .._run_context import AgentDepsT, RunContext +from ..exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior +from ..tools import ToolDefinition +from . import AbstractToolset +from .wrapper import WrapperToolset + + +@dataclass(init=False) +class RunToolset(WrapperToolset[AgentDepsT]): + """A toolset that caches the wrapped toolset's tool definitions for a specific run step and handles retries.""" + + ctx: RunContext[AgentDepsT] + _tool_defs: list[ToolDefinition] + _tool_names: list[str] + _retries: dict[str, int] + _original: AbstractToolset[AgentDepsT] + + def __init__( + self, + wrapped: AbstractToolset[AgentDepsT], + ctx: RunContext[AgentDepsT], + tool_defs: list[ToolDefinition] | None = None, + original: AbstractToolset[AgentDepsT] | None = None, + ): + self.wrapped = wrapped + self.ctx = ctx + self._tool_defs = wrapped.tool_defs if tool_defs is None else tool_defs + self._tool_names = [tool_def.name for tool_def in self._tool_defs] + self._retries = ctx.retries.copy() + self._original = original or wrapped + + @property + def name(self) -> str: + return self.wrapped.name + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + if ctx == self.ctx: + return self + else: + ctx = replace(ctx, retries=self._retries) + return await self._original.prepare_for_run(ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return self._tool_defs + + @property + def tool_names(self) -> list[str]: + return self._tool_names + + def validate_tool_args( + self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False + ) -> dict[str, Any]: + with self._with_retry(name, ctx) as ctx: + return super().validate_tool_args(ctx, name, args, allow_partial) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + with self._with_retry(name, ctx) as ctx: + try: + output = await super().call_tool(ctx, name, tool_args, *args, **kwargs) + except Exception as e: + raise e + else: + self._retries.pop(name, None) + return output + + def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None: + deferred_calls_and_defs = [ + (part, tool_def) + for part in parts + if isinstance(part, _messages.ToolCallPart) + and (tool_def := self.get_tool_def(part.tool_name)) + and tool_def.kind == 'deferred' + ] + if not deferred_calls_and_defs: + return None + + deferred_calls: list[_messages.ToolCallPart] = [] + deferred_tool_defs: dict[str, ToolDefinition] = {} + for part, tool_def in deferred_calls_and_defs: + deferred_calls.append(part) + deferred_tool_defs[part.tool_name] = tool_def + + return DeferredToolCalls(deferred_calls, deferred_tool_defs) + + @contextmanager + def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunContext[AgentDepsT]]: + try: + if name not in self.tool_names: + if self.tool_names: + msg = f'Available tools: {", ".join(f"{name!r}" for name in self.tool_names)}' + else: + msg = 'No tools available.' + raise ModelRetry(f'Unknown tool name: {name!r}. {msg}') + + ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0), retries={}) + yield ctx + except (ValidationError, ModelRetry, UnexpectedModelBehavior, ToolRetryError) as e: + try: + max_retries = self._max_retries_for_tool(name) + except Exception: + max_retries = 1 + current_retry = self._retries.get(name, 0) + + if isinstance(e, UnexpectedModelBehavior) and e.__cause__ is not None: + e = e.__cause__ + + if current_retry == max_retries: + raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e + else: + if ctx.tool_call_id: + if isinstance(e, ValidationError): + m = _messages.RetryPromptPart( + tool_name=name, + content=e.errors(include_url=False, include_context=False), + tool_call_id=ctx.tool_call_id, + ) + e = ToolRetryError(m) + elif isinstance(e, ModelRetry): + m = _messages.RetryPromptPart( + tool_name=name, + content=e.message, + tool_call_id=ctx.tool_call_id, + ) + e = ToolRetryError(m) + + self._retries[name] = current_retry + 1 + raise e diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py new file mode 100644 index 000000000..738f82a3c --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Sequence +from contextlib import AsyncExitStack +from dataclasses import dataclass +from types import TracebackType +from typing import TYPE_CHECKING, Any + +from pydantic_core import SchemaValidator +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from ..exceptions import UserError +from ..tools import ToolDefinition +from . import AbstractToolset +from ._run import RunToolset + +if TYPE_CHECKING: + pass + + +@dataclass(init=False) +class CombinedToolset(AbstractToolset[AgentDepsT]): + """A toolset that combines multiple toolsets.""" + + toolsets: list[AbstractToolset[AgentDepsT]] + _toolset_per_tool_name: dict[str, AbstractToolset[AgentDepsT]] + _exit_stack: AsyncExitStack | None + _running_count: int + + def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]): + self._exit_stack = None + self._running_count = 0 + self.toolsets = list(toolsets) + + self._toolset_per_tool_name = {} + for toolset in self.toolsets: + for name in toolset.tool_names: + try: + existing_toolset = self._toolset_per_tool_name[name] + raise UserError( + f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_toolset.name}: {name!r}. {toolset.tool_name_conflict_hint}' + ) + except KeyError: + pass + self._toolset_per_tool_name[name] = toolset + + async def __aenter__(self) -> Self: + if self._running_count == 0: + self._exit_stack = AsyncExitStack() + for toolset in self.toolsets: + await self._exit_stack.enter_async_context(toolset) + self._running_count += 1 + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> bool | None: + self._running_count -= 1 + if self._running_count <= 0 and self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None + return None + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + toolsets_for_run = await asyncio.gather(*[toolset.prepare_for_run(ctx) for toolset in self.toolsets]) + combined_for_run = CombinedToolset(toolsets_for_run) + return RunToolset(combined_for_run, ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [tool_def for toolset in self.toolsets for tool_def in toolset.tool_defs] + + @property + def tool_names(self) -> list[str]: + return list(self._toolset_per_tool_name.keys()) + + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self._toolset_for_tool_name(name)._get_tool_args_validator(ctx, name) + + def validate_tool_args( + self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False + ) -> dict[str, Any]: + return self._toolset_for_tool_name(name).validate_tool_args(ctx, name, args, allow_partial) + + def _max_retries_for_tool(self, name: str) -> int: + return self._toolset_for_tool_name(name)._max_retries_for_tool(name) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await self._toolset_for_tool_name(name).call_tool(ctx, name, tool_args, *args, **kwargs) + + def _toolset_for_tool_name(self, name: str) -> AbstractToolset[AgentDepsT]: + try: + return self._toolset_per_tool_name[name] + except KeyError as e: + raise ValueError(f'Tool {name!r} not found in any toolset') from e diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py new file mode 100644 index 000000000..b6b1f8806 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import replace +from typing import Any + +from pydantic_core import SchemaValidator + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition +from . import AbstractToolset +from ._run import RunToolset + + +class DeferredToolset(AbstractToolset[AgentDepsT]): + """A toolset that holds deferred tool.""" + + _tool_defs: list[ToolDefinition] + + def __init__(self, tool_defs: list[ToolDefinition]): + self._tool_defs = tool_defs + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + return RunToolset(self, ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [replace(tool_def, kind='deferred') for tool_def in self._tool_defs] + + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + raise NotImplementedError('Deferred tools cannot be validated') + + def _max_retries_for_tool(self, name: str) -> int: + raise NotImplementedError('Deferred tools cannot be retried') + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + raise NotImplementedError('Deferred tools cannot be called') diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py b/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py new file mode 100644 index 000000000..336b18a39 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition +from . import AbstractToolset +from ._individually_prepared import IndividuallyPreparedToolset + + +@dataclass(init=False) +class FilteredToolset(IndividuallyPreparedToolset[AgentDepsT]): + """A toolset that filters the tools it contains using a filter function.""" + + def __init__( + self, + toolset: AbstractToolset[AgentDepsT], + filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool], + ): + async def filter_tool_def(ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: + return tool_def if filter_func(ctx, tool_def) else None + + super().__init__(toolset, filter_tool_def) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py new file mode 100644 index 000000000..fbc60f8b0 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass, field, replace +from typing import Any, Callable, overload + +from pydantic.json_schema import GenerateJsonSchema +from pydantic_core import SchemaValidator + +from .._run_context import AgentDepsT, RunContext +from ..exceptions import UserError +from ..tools import ( + DocstringFormat, + GenerateToolJsonSchema, + Tool, + ToolDefinition, + ToolFuncEither, + ToolParams, + ToolPrepareFunc, +) +from . import AbstractToolset +from ._individually_prepared import IndividuallyPreparedToolset +from ._run import RunToolset + + +@dataclass(init=False) +class FunctionToolset(AbstractToolset[AgentDepsT]): + """A toolset that lets Python functions be used as tools.""" + + max_retries: int = field(default=1) + tools: dict[str, Tool[Any]] = field(default_factory=dict) + + def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1): + self.max_retries = max_retries + self.tools = {} + for tool in tools: + if isinstance(tool, Tool): + self.register_tool(tool) + else: + self.register_function(tool) + + @overload + def tool(self, func: ToolFuncEither[AgentDepsT, ToolParams], /) -> ToolFuncEither[AgentDepsT, ToolParams]: ... + + @overload + def tool( + self, + /, + *, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> Callable[[ToolFuncEither[AgentDepsT, ToolParams]], ToolFuncEither[AgentDepsT, ToolParams]]: ... + + def tool( + self, + func: ToolFuncEither[AgentDepsT, ToolParams] | None = None, + /, + *, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> Any: + """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. + + Can decorate a sync or async functions. + + The docstring is inspected to extract both the tool description and description of each parameter, + [learn more](../tools.md#function-tools-and-schema). + + We can't add overloads for every possible signature of tool, since the return type is a recursive union + so the signature of functions decorated with `@agent.tool` is obscured. + + Example: + ```python + from pydantic_ai import Agent, RunContext + + agent = Agent('test', deps_type=int) + + @agent.tool + def foobar(ctx: RunContext[int], x: int) -> int: + return ctx.deps + x + + @agent.tool(retries=2) + async def spam(ctx: RunContext[str], y: float) -> float: + return ctx.deps + y + + result = agent.run_sync('foobar', deps=1) + print(result.output) + #> {"foobar":1,"spam":1.0} + ``` + + Args: + func: The tool function to register. + name: The name of the tool, defaults to the function name. + retries: The number of retries to allow for this tool, defaults to the agent's default retries, + which defaults to 1. + prepare: custom method to prepare the tool definition for each step, return `None` to omit this + tool from a given step. This is useful if you want to customise a tool at call time, + or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. + docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. + Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. + require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. + schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. + strict: Whether to enforce JSON schema compliance (only affects OpenAI). + See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. + """ + if func is None: + + def tool_decorator( + func_: ToolFuncEither[AgentDepsT, ToolParams], + ) -> ToolFuncEither[AgentDepsT, ToolParams]: + # noinspection PyTypeChecker + self.register_function( + func_, + None, + name, + retries, + prepare, + docstring_format, + require_parameter_descriptions, + schema_generator, + strict, + ) + return func_ + + return tool_decorator + else: + # noinspection PyTypeChecker + self.register_function( + func, + None, + name, + retries, + prepare, + docstring_format, + require_parameter_descriptions, + schema_generator, + strict, + ) + return func + + def register_function( + self, + func: ToolFuncEither[AgentDepsT, ToolParams], + takes_ctx: bool | None = None, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> None: + """Register a function as a tool.""" + tool = Tool[AgentDepsT]( + func, + takes_ctx=takes_ctx, + name=name, + max_retries=retries, + prepare=prepare, + docstring_format=docstring_format, + require_parameter_descriptions=require_parameter_descriptions, + schema_generator=schema_generator, + strict=strict, + ) + self.register_tool(tool) + + def register_tool(self, tool: Tool[AgentDepsT]) -> None: + if tool.name in self.tools: + raise UserError(f'Tool name conflicts with existing tool: {tool.name!r}') + if tool.max_retries is None: + tool.max_retries = self.max_retries + self.tools[tool.name] = tool + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + self_for_run = RunToolset(self, ctx) + prepared_for_run = await IndividuallyPreparedToolset(self_for_run, self._prepare_tool_def).prepare_for_run(ctx) + return RunToolset(prepared_for_run, ctx, original=self) + + async def _prepare_tool_def(self, ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: + tool_name = tool_def.name + ctx = replace(ctx, tool_name=tool_name, retry=ctx.retries.get(tool_name, 0)) + return await self.tools[tool_name].prepare_tool_def(ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [tool.tool_def for tool in self.tools.values()] + + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self.tools[name].function_schema.validator + + def _max_retries_for_tool(self, name: str) -> int: + tool = self.tools[name] + return tool.max_retries if tool.max_retries is not None else self.max_retries + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await self.tools[name].function_schema.call(tool_args, ctx) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py new file mode 100644 index 000000000..9210746ae --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import Any + +from pydantic_core import SchemaValidator + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition +from ._run import RunToolset +from .wrapper import WrapperToolset + + +@dataclass +class PrefixedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prefixes the names of the tools it contains.""" + + prefix: str + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + prefixed_for_run = PrefixedToolset(wrapped_for_run, self.prefix) + return RunToolset(prefixed_for_run, ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [replace(tool_def, name=self._prefixed_tool_name(tool_def.name)) for tool_def in super().tool_defs] + + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return super()._get_tool_args_validator(ctx, self._unprefixed_tool_name(name)) + + def _max_retries_for_tool(self, name: str) -> int: + return super()._max_retries_for_tool(self._unprefixed_tool_name(name)) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await super().call_tool(ctx, self._unprefixed_tool_name(name), tool_args, *args, **kwargs) + + def _prefixed_tool_name(self, tool_name: str) -> str: + return f'{self.prefix}_{tool_name}' + + def _unprefixed_tool_name(self, tool_name: str) -> str: + full_prefix = f'{self.prefix}_' + if not tool_name.startswith(full_prefix): + raise ValueError(f"Tool name '{tool_name}' does not start with prefix '{full_prefix}'") + return tool_name[len(full_prefix) :] diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py new file mode 100644 index 000000000..f35d7154d --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from .._run_context import AgentDepsT, RunContext +from ..exceptions import UserError +from ..tools import ToolsPrepareFunc +from ._run import RunToolset +from .wrapper import WrapperToolset + + +@dataclass +class PreparedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prepares the tools it contains using a prepare function.""" + + prepare_func: ToolsPrepareFunc[AgentDepsT] + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + original_tool_defs = wrapped_for_run.tool_defs + prepared_tool_defs = await self.prepare_func(ctx, original_tool_defs) or [] + + original_tool_names = {tool_def.name for tool_def in original_tool_defs} + prepared_tool_names = {tool_def.name for tool_def in prepared_tool_defs} + if len(prepared_tool_names - original_tool_names) > 0: + raise UserError('Prepare function is not allowed to change tool names or add new tools.') + + prepared_for_run = PreparedToolset(wrapped_for_run, self.prepare_func) + return RunToolset(prepared_for_run, ctx, prepared_tool_defs) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/processed.py b/pydantic_ai_slim/pydantic_ai/toolsets/processed.py new file mode 100644 index 000000000..c63854f7b --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/processed.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from collections.abc import Awaitable +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Protocol + +from .._run_context import AgentDepsT, RunContext +from ._run import RunToolset +from .wrapper import WrapperToolset + + +class CallToolFunc(Protocol): + """A function protocol that represents a tool call.""" + + def __call__(self, name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any) -> Awaitable[Any]: ... + + +ToolProcessFunc = Callable[ + [ + RunContext[AgentDepsT], + CallToolFunc, + str, + dict[str, Any], + ], + Awaitable[Any], +] + + +@dataclass +class ProcessedToolset(WrapperToolset[AgentDepsT]): + """A toolset that lets the tool call arguments and return value be customized using a wrapper function.""" + + process: ToolProcessFunc[AgentDepsT] + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + processed = ProcessedToolset(wrapped_for_run, self.process) + return RunToolset(processed, ctx) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await self.process(ctx, partial(self.wrapped.call_tool, ctx), name, tool_args, *args, **kwargs) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py new file mode 100644 index 000000000..354de8ebc --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from types import TracebackType +from typing import TYPE_CHECKING, Any + +from pydantic_core import SchemaValidator +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition +from . import AbstractToolset + +if TYPE_CHECKING: + from ._run import RunToolset + + +@dataclass +class WrapperToolset(AbstractToolset[AgentDepsT], ABC): + """A toolset that wraps another toolset and delegates to it.""" + + wrapped: AbstractToolset[AgentDepsT] + + @property + def name(self) -> str: + return self.wrapped.name + + @property + def tool_name_conflict_hint(self) -> str: + return self.wrapped.tool_name_conflict_hint + + async def __aenter__(self) -> Self: + await self.wrapped.__aenter__() + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> bool | None: + return await self.wrapped.__aexit__(exc_type, exc_value, traceback) + + @abstractmethod + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + raise NotImplementedError() + + @property + def tool_defs(self) -> list[ToolDefinition]: + return self.wrapped.tool_defs + + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self.wrapped._get_tool_args_validator(ctx, name) + + def _max_retries_for_tool(self, name: str) -> int: + return self.wrapped._max_retries_for_tool(name) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await self.wrapped.call_tool(ctx, name, tool_args, *args, **kwargs) + + def __getattr__(self, item: str): + return getattr(self.wrapped, item) # pragma: no cover diff --git a/tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml b/tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml new file mode 100644 index 000000000..e33e36f96 --- /dev/null +++ b/tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml @@ -0,0 +1,391 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '2501' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is 0 degrees Celsius in Fahrenheit? + role: user + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + properties: + celsius: + type: number + required: + - celsius + type: object + type: function + - function: + description: "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + properties: + location: + type: string + required: + - location + type: object + type: function + - function: + description: '' + name: get_image_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_image + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_dict + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_error + parameters: + properties: + value: + type: boolean + type: object + type: function + - function: + description: '' + name: get_none + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_multiple_items + parameters: + properties: {} + type: object + type: function + - function: + description: "Get the current log level.\n\n Returns:\n The current log level.\n " + name: get_log_level + parameters: + properties: {} + type: object + type: function + - function: + description: "Echo the run context.\n\n Args:\n ctx: Context object containing request and session information.\n\n + \ Returns:\n Dictionary with an echo message and the deps.\n " + name: echo_deps + parameters: + properties: {} + type: object + type: function + - function: + description: Use sampling callback. + name: use_sampling + parameters: + properties: + foo: + type: string + required: + - foo + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1086' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '420' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{"celsius":0}' + name: celsius_to_fahrenheit + id: call_hS0oexgCNI6TneJuPPuwn9jQ + type: function + created: 1751491994 + id: chatcmpl-BozMoBhgfC5D8QBjkiOwz5OxxrwQK + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_a288987b44 + usage: + completion_tokens: 18 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 268 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 286 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '2748' + content-type: + - application/json + cookie: + - __cf_bm=JOV7WG2Y48FZrZxdh0IZvA9mCj_ljIN3DhGMuC1pw6M-1751491995-1.0.1.1-zGPrLbzYx7y3iZT28xogbHO1KAIej60kPEwQ8ZxGMxv1r.ICtqI0T8WCnlyUccKfLSXB6ZTNQT05xCma8LSvq2pk4X2eEuSkYC1sPqbuLU8; + _cfuvid=LdoyX0uKYwM98NSSSvySlZAiJHCVHz_1krUGKbWmNHg-1751491995391-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is 0 degrees Celsius in Fahrenheit? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{"celsius":0}' + name: celsius_to_fahrenheit + id: call_hS0oexgCNI6TneJuPPuwn9jQ + type: function + - content: '32.0' + role: tool + tool_call_id: call_hS0oexgCNI6TneJuPPuwn9jQ + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + properties: + celsius: + type: number + required: + - celsius + type: object + type: function + - function: + description: "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + properties: + location: + type: string + required: + - location + type: object + type: function + - function: + description: '' + name: get_image_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_image + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_dict + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_error + parameters: + properties: + value: + type: boolean + type: object + type: function + - function: + description: '' + name: get_none + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_multiple_items + parameters: + properties: {} + type: object + type: function + - function: + description: "Get the current log level.\n\n Returns:\n The current log level.\n " + name: get_log_level + parameters: + properties: {} + type: object + type: function + - function: + description: "Echo the run context.\n\n Args:\n ctx: Context object containing request and session information.\n\n + \ Returns:\n Dictionary with an echo message and the deps.\n " + name: echo_deps + parameters: + properties: {} + type: object + type: function + - function: + description: Use sampling callback. + name: use_sampling + parameters: + properties: + foo: + type: string + required: + - foo + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '849' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '520' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: 0 degrees Celsius is 32.0 degrees Fahrenheit. + refusal: null + role: assistant + created: 1751491998 + id: chatcmpl-BozMsevK8quJblNOyNCaDQpdtDwI5 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_a288987b44 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 300 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 312 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/ext/test_langchain.py b/tests/ext/test_langchain.py index 73e7cc050..926a22819 100644 --- a/tests/ext/test_langchain.py +++ b/tests/ext/test_langchain.py @@ -6,7 +6,7 @@ from pydantic.json_schema import JsonSchemaValue from pydantic_ai import Agent -from pydantic_ai.ext.langchain import tool_from_langchain +from pydantic_ai.ext.langchain import LangChainToolset, tool_from_langchain @dataclass @@ -49,24 +49,26 @@ def get_input_jsonschema(self) -> JsonSchemaValue: } -def test_langchain_tool_conversion(): - langchain_tool = SimulatedLangChainTool( - name='file_search', - description='Recursively search for files in a subdirectory that match the regex pattern', - args={ - 'dir_path': { - 'default': '.', - 'description': 'Subdirectory to search in.', - 'title': 'Dir Path', - 'type': 'string', - }, - 'pattern': { - 'description': 'Unix shell regex, where * matches everything.', - 'title': 'Pattern', - 'type': 'string', - }, +langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', }, - ) + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, +) + + +def test_langchain_tool_conversion(): pydantic_tool = tool_from_langchain(langchain_tool) agent = Agent('test', tools=[pydantic_tool], retries=7) @@ -74,6 +76,13 @@ def test_langchain_tool_conversion(): assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") +def test_langchain_toolset(): + toolset = LangChainToolset([langchain_tool]) + agent = Agent('test', toolsets=[toolset], retries=7) + result = agent.run_sync('foobar') + assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") + + def test_langchain_tool_no_additional_properties(): langchain_tool = SimulatedLangChainTool( name='file_search', diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 3891c5108..77857e882 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1700,7 +1700,7 @@ class CityLocation(BaseModel): agent = Agent(m, output_type=NativeOutput(CityLocation)) - with pytest.raises(UserError, match='Structured output is not supported by the model.'): + with pytest.raises(UserError, match='Native structured output is not supported by the model.'): await agent.run('What is the largest city in the user country?') diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index 31635c080..02aafd259 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -4,6 +4,7 @@ import asyncio import dataclasses +import re from datetime import timezone from typing import Annotated, Any, Literal @@ -157,7 +158,7 @@ def validate_output(ctx: RunContext[None], output: OutputModel) -> OutputModel: call_count += 1 raise ModelRetry('Fail') - with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'): + with pytest.raises(UnexpectedModelBehavior, match=re.escape('Exceeded maximum retries (2) for output validation')): agent.run_sync('Hello', model=TestModel()) assert call_count == 3 @@ -200,7 +201,7 @@ class ResultModel(BaseModel): agent = Agent('test', output_type=ResultModel, retries=2) - with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(2\) for output validation'): agent.run_sync('Hello', model=TestModel(custom_output_args={'foo': 'a', 'bar': 1})) diff --git a/tests/test_agent.py b/tests/test_agent.py index b85a9f5d8..52c913aa6 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,6 +1,7 @@ import json import re import sys +from dataclasses import dataclass from datetime import timezone from typing import Any, Callable, Union @@ -45,6 +46,7 @@ from pydantic_ai.profiles import ModelProfile from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.function import FunctionToolset from .conftest import IsDatetime, IsNow, IsStr, TestEnv @@ -396,6 +398,7 @@ def test_response_tuple(): 'type': 'object', }, outer_typed_dict_key='response', + kind='output', ) ] ) @@ -469,6 +472,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: 'title': 'Foo', 'type': 'object', }, + kind='output', ) ] ) @@ -548,6 +552,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: 'title': 'Foo', 'type': 'object', }, + kind='output', ), ToolDefinition( name='final_result_Bar', @@ -558,6 +563,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: 'title': 'Bar', 'type': 'object', }, + kind='output', ), ] ) @@ -589,6 +595,7 @@ class MyOutput(BaseModel): 'title': 'MyOutput', 'type': 'object', }, + kind='output', ) ] ) @@ -635,6 +642,7 @@ class Bar(BaseModel): }, outer_typed_dict_key='response', strict=False, + kind='output', ) ] ) @@ -673,6 +681,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -712,6 +721,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -752,6 +762,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -793,6 +804,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -943,7 +955,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: [[str, str], [str, TextOutput(upcase)], [TextOutput(upcase), TextOutput(upcase)]], ) def test_output_type_multiple_text_output(output_type: OutputSpec[str]): - with pytest.raises(UserError, match='Only one text output is allowed.'): + with pytest.raises(UserError, match='Only one `str` or `TextOutput` is allowed.'): Agent('test', output_type=output_type) @@ -989,6 +1001,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -1027,6 +1040,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -1065,6 +1079,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ), ToolDefinition( name='final_result_Weather', @@ -1075,6 +1090,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'title': 'Weather', 'type': 'object', }, + kind='output', ), ] ) @@ -1251,6 +1267,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ), ToolDefinition( name='return_weather', @@ -1261,6 +1278,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'title': 'Weather', 'type': 'object', }, + kind='output', ), ] ) @@ -1927,7 +1945,7 @@ def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: agent = Agent(FunctionModel(empty)) with capture_run_messages() as messages: - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for output validation'): agent.run_sync('Hello') assert messages == snapshot( [ @@ -2279,12 +2297,6 @@ def another_tool(y: int) -> int: tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), - RetryPromptPart( - tool_name='unknown_tool', - content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", - timestamp=IsNow(tz=timezone.utc), - tool_call_id=IsStr(), - ), ToolReturnPart( tool_name='regular_tool', content=42, @@ -2294,6 +2306,12 @@ def another_tool(y: int) -> int: ToolReturnPart( tool_name='another_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ), + RetryPromptPart( + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", + tool_name='unknown_tool', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ), ] ), ] @@ -2357,16 +2375,16 @@ def another_tool(y: int) -> int: # pragma: no cover ModelRequest( parts=[ ToolReturnPart( - tool_name='regular_tool', - content='Tool not executed - a final result was already processed.', + tool_name='final_result', + content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( - tool_name='final_result', - content='Final result processed.', + tool_name='regular_tool', + content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), - timestamp=IsNow(tz=timezone.utc), + timestamp=IsDatetime(), ), ToolReturnPart( tool_name='another_tool', @@ -2376,7 +2394,7 @@ def another_tool(y: int) -> int: # pragma: no cover ), RetryPromptPart( tool_name='unknown_tool', - content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), ), @@ -2423,11 +2441,13 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: # Verify we got appropriate tool returns assert result.new_messages()[-1].parts == snapshot( [ - ToolReturnPart( + RetryPromptPart( + content=[ + {'type': 'missing', 'loc': ('value',), 'msg': 'Field required', 'input': {'bad_value': 'first'}} + ], tool_name='final_result', tool_call_id='first', - content='Output tool not used - result failed validation.', - timestamp=IsNow(tz=timezone.utc), + timestamp=IsDatetime(), ), ToolReturnPart( tool_name='final_result', @@ -3132,7 +3152,7 @@ def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: agent = Agent(model, output_type=NativeOutput(Foo)) - with pytest.raises(UserError, match='Structured output is not supported by the model.'): + with pytest.raises(UserError, match='Native structured output is not supported by the model.'): agent.run_sync('Hello') agent = Agent(model, output_type=ToolOutput(Foo)) @@ -3433,3 +3453,133 @@ def test_deprecated_kwargs_mixed_valid_invalid(): with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) # Ignore the deprecation warning for result_tool_name Agent('test', result_tool_name='test', foo='value1', bar='value2') # type: ignore[call-arg] + + +def test_override_toolsets(): + foo_toolset = FunctionToolset() + + @foo_toolset.tool + def foo() -> str: + return 'Hello from foo' + + bar_toolset = FunctionToolset() + + @bar_toolset.tool + def bar() -> str: + return 'Hello from bar' + + available_tools: list[list[str]] = [] + + async def prepare_tools(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + nonlocal available_tools + available_tools.append([tool_def.name for tool_def in tool_defs]) + return tool_defs + + agent = Agent('test', toolsets=[foo_toolset], prepare_tools=prepare_tools) + + @agent.tool_plain + def baz() -> str: + return 'Hello from baz' + + result = agent.run_sync('Hello') + assert available_tools[-1] == snapshot(['baz', 'foo']) + assert result.output == snapshot('{"baz":"Hello from baz","foo":"Hello from foo"}') + + result = agent.run_sync('Hello', toolsets=[bar_toolset]) + assert available_tools[-1] == snapshot(['baz', 'bar']) + assert result.output == snapshot('{"baz":"Hello from baz","bar":"Hello from bar"}') + + result = agent.run_sync('Hello', toolsets=[]) + assert available_tools[-1] == snapshot(['baz']) + assert result.output == snapshot('{"baz":"Hello from baz"}') + + +def test_prepare_output_tools(): + @dataclass + class AgentDeps: + plan_presented: bool = False + + async def present_plan(ctx: RunContext[AgentDeps], plan: str) -> str: + """ + Present the plan to the user. + """ + ctx.deps.plan_presented = True + return plan + + async def run_sql(ctx: RunContext[AgentDeps], purpose: str, query: str) -> str: + """ + Run an SQL query. + """ + return 'SQL query executed successfully' + + async def only_if_plan_presented( + ctx: RunContext[AgentDeps], tool_defs: list[ToolDefinition] + ) -> list[ToolDefinition]: + return tool_defs if ctx.deps.plan_presented else [] + + agent = Agent( + model='test', + deps_type=AgentDeps, + tools=[present_plan], + output_type=[ToolOutput(run_sql, name='run_sql')], + prepare_output_tools=only_if_plan_presented, + ) + + result = agent.run_sync('Hello', deps=AgentDeps()) + assert result.output == snapshot('SQL query executed successfully') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Hello', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='present_plan', + args={'plan': 'a'}, + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=51, response_tokens=5, total_tokens=56), + model_name='test', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='present_plan', + content='a', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='run_sql', + args={'purpose': 'a', 'query': 'a'}, + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=52, response_tokens=12, total_tokens=64), + model_name='test', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='run_sql', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) diff --git a/tests/test_examples.py b/tests/test_examples.py index 6b4626f05..03ec1c342 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -17,11 +17,13 @@ import pytest from _pytest.mark import ParameterSet from devtools import debug +from pydantic_core import SchemaValidator, core_schema from pytest_examples import CodeExample, EvalExample, find_examples from pytest_mock import MockerFixture from rich.console import Console from pydantic_ai import ModelHTTPError +from pydantic_ai._run_context import RunContext from pydantic_ai._utils import group_by_temporal from pydantic_ai.exceptions import UnexpectedModelBehavior from pydantic_ai.messages import ( @@ -37,6 +39,9 @@ from pydantic_ai.models.fallback import FallbackModel from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets import AbstractToolset +from pydantic_ai.toolsets._run import RunToolset from .conftest import ClientWithHandler, TestEnv, try_import @@ -257,7 +262,7 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: raise ValueError(f'Unexpected prompt: {prompt}') -class MockMCPServer: +class MockMCPServer(AbstractToolset[Any]): is_running = True override_sampling_model = nullcontext @@ -267,10 +272,24 @@ async def __aenter__(self) -> MockMCPServer: async def __aexit__(self, *args: Any) -> None: pass - @staticmethod - async def list_tools() -> list[None]: + async def prepare_for_run(self, ctx: RunContext[Any]) -> RunToolset[Any]: + return RunToolset(self, ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: return [] + def _get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> SchemaValidator: + return SchemaValidator(core_schema.any_schema()) # pragma: lax no cover + + def _max_retries_for_tool(self, name: str) -> int: + return 0 # pragma: lax no cover + + async def call_tool( + self, ctx: RunContext[Any], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return None # pragma: lax no cover + text_responses: dict[str, str | ToolCallPart] = { 'How many days between 2000-01-01 and 2025-03-18?': 'There are 9,208 days between January 1, 2000, and March 18, 2025.', diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 34aff7514..10e7a9ec1 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -289,6 +289,7 @@ async def my_ret(x: int) -> str: }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ], 'output_mode': 'text', diff --git a/tests/test_mcp.py b/tests/test_mcp.py index b90b7135e..bb8181866 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1,5 +1,7 @@ """Tests for the MCP (Model Context Protocol) server implementation.""" +from __future__ import annotations + import base64 import re from datetime import timezone @@ -23,8 +25,10 @@ ToolReturnPart, UserPromptPart, ) +from pydantic_ai.models import Model from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext +from pydantic_ai.toolsets.processed import CallToolFunc from pydantic_ai.usage import Usage from .conftest import IsDatetime, IsNow, IsStr, try_import @@ -34,7 +38,7 @@ from mcp.types import CreateMessageRequestParams, ImageContent, TextContent from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response - from pydantic_ai.mcp import CallToolFunc, MCPServerSSE, MCPServerStdio, ToolResult + from pydantic_ai.mcp import MCPServerSSE, MCPServerStdio, ToolResult from pydantic_ai.models.google import GoogleModel from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.google import GoogleProvider @@ -48,22 +52,35 @@ @pytest.fixture -def agent(openai_api_key: str): - server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) - model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) - return Agent(model, mcp_servers=[server]) +def mcp_server() -> MCPServerStdio: + return MCPServerStdio('python', ['-m', 'tests.mcp_server']) + + +@pytest.fixture +def model(openai_api_key: str) -> Model: + return OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + +@pytest.fixture +def agent(model: Model, mcp_server: MCPServerStdio) -> Agent: + return Agent(model, toolsets=[mcp_server]) + + +@pytest.fixture +def run_context(model: Model) -> RunContext[int]: + return RunContext(deps=0, model=model, usage=Usage(), sampling_model=model) -async def test_stdio_server(): +async def test_stdio_server(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: - tools = await server.list_tools() + tools = (await server.prepare_for_run(run_context)).tool_defs assert len(tools) == snapshot(13) assert tools[0].name == 'celsius_to_fahrenheit' assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') # Test calling the temperature conversion tool - result = await server.call_tool('celsius_to_fahrenheit', {'celsius': 0}) + result = await server.call_tool(run_context, 'celsius_to_fahrenheit', {'celsius': 0}) assert result == snapshot('32.0') @@ -74,38 +91,38 @@ async def test_reentrant_context_manager(): pass -async def test_stdio_server_with_tool_prefix(): +async def test_stdio_server_with_tool_prefix(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], tool_prefix='foo') async with server: - tools = await server.list_tools() + tools = (await server.prepare_for_run(run_context)).tool_defs assert all(tool.name.startswith('foo_') for tool in tools) -async def test_stdio_server_with_cwd(): +async def test_stdio_server_with_cwd(run_context: RunContext[int]): test_dir = Path(__file__).parent server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) async with server: - tools = await server.list_tools() + tools = (await server.prepare_for_run(run_context)).tool_defs assert len(tools) == snapshot(13) -async def test_process_tool_call() -> None: +async def test_process_tool_call(run_context: RunContext[int]) -> int: called: bool = False async def process_tool_call( ctx: RunContext[int], call_tool: CallToolFunc, - tool_name: str, + name: str, args: dict[str, Any], ) -> ToolResult: """A process_tool_call that sets a flag and sends deps as metadata.""" nonlocal called called = True - return await call_tool(tool_name, args, {'deps': ctx.deps}) + return await call_tool(name, args, metadata={'deps': ctx.deps}) server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], process_tool_call=process_tool_call) async with server: - agent = Agent(deps_type=int, model=TestModel(call_tools=['echo_deps']), mcp_servers=[server]) + agent = Agent(deps_type=int, model=TestModel(call_tools=['echo_deps']), toolsets=[server]) result = await agent.run('Echo with deps set to 42', deps=42) assert result.output == snapshot('{"echo_deps":{"echo":"This is an echo message","deps":42}}') assert called, 'process_tool_call should have been called' @@ -134,7 +151,7 @@ def test_sse_server_with_header_and_timeout(): @pytest.mark.vcr() async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('What is 0 degrees Celsius in Fahrenheit?') assert result.output == snapshot('0 degrees Celsius is equal to 32 degrees Fahrenheit.') assert result.all_messages() == snapshot( @@ -211,11 +228,11 @@ def get_none() -> None: # pragma: no cover """Return nothing""" return None - async with agent.run_mcp_servers(): + async with agent: with pytest.raises( UserError, match=re.escape( - "MCP Server 'MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'], tool_prefix=None)' defines a tool whose name conflicts with existing tool: 'get_none'. Consider using `tool_prefix` to avoid name conflicts." + "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'], tool_prefix=None) defines a tool whose name conflicts with existing tool from Function toolset: 'get_none'. Consider setting `tool_prefix` to avoid name conflicts." ), ): await agent.run('Get me a conflict') @@ -226,7 +243,7 @@ async def test_agent_with_prefix_tool_name(openai_api_key: str): model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) agent = Agent( model, - mcp_servers=[server], + toolsets=[server], ) @agent.tool_plain @@ -234,43 +251,41 @@ def get_none() -> None: # pragma: no cover """Return nothing""" return None - async with agent.run_mcp_servers(): + async with agent: # This means that we passed the _prepare_request_parameters check and there is no conflict in the tool name with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'): await agent.run('No conflict') -async def test_agent_with_server_not_running(openai_api_key: str): - server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) - model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) - agent = Agent(model, mcp_servers=[server]) - with pytest.raises(UserError, match='MCP server is not running'): - await agent.run('What is 0 degrees Celsius in Fahrenheit?') +@pytest.mark.vcr() +async def test_agent_with_server_not_running(agent: Agent, allow_model_requests: None): + result = await agent.run('What is 0 degrees Celsius in Fahrenheit?') + assert result.output == snapshot('0 degrees Celsius is 32.0 degrees Fahrenheit.') -async def test_log_level_unset(): +async def test_log_level_unset(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) assert server.log_level is None async with server: - tools = await server.list_tools() + tools = (await server.prepare_for_run(run_context)).tool_defs assert len(tools) == snapshot(13) assert tools[10].name == 'get_log_level' - result = await server.call_tool('get_log_level', {}) + result = await server.call_tool(run_context, 'get_log_level', {}) assert result == snapshot('unset') -async def test_log_level_set(): +async def test_log_level_set(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], log_level='info') assert server.log_level == 'info' async with server: - result = await server.call_tool('get_log_level', {}) + result = await server.call_tool(run_context, 'get_log_level', {}) assert result == snapshot('info') @pytest.mark.vcr() async def test_tool_returning_str(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('What is the weather in Mexico City?') assert result.output == snapshot( 'The weather in Mexico City is currently sunny with a temperature of 26 degrees Celsius.' @@ -349,7 +364,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_text_resource(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me the product name') assert result.output == snapshot('The product name is "PydanticAI".') assert result.all_messages() == snapshot( @@ -422,7 +437,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A @pytest.mark.vcr() async def test_tool_returning_image_resource(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me the image resource') assert result.output == snapshot( 'This is an image of a sliced kiwi with a vibrant green interior and black seeds.' @@ -505,7 +520,7 @@ async def test_tool_returning_audio_resource( allow_model_requests: None, agent: Agent, audio_content: BinaryContent, gemini_api_key: str ): model = GoogleModel('gemini-2.5-pro-preview-03-25', provider=GoogleProvider(api_key=gemini_api_key)) - async with agent.run_mcp_servers(): + async with agent: result = await agent.run("What's the content of the audio resource?", model=model) assert result.output == snapshot('The audio resource contains a voice saying "Hello, my name is Marcelo."') assert result.all_messages() == snapshot( @@ -556,7 +571,7 @@ async def test_tool_returning_audio_resource( @pytest.mark.vcr() async def test_tool_returning_image(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me an image') assert result.output == snapshot('Here is an image of a sliced kiwi on a white background.') assert result.all_messages() == snapshot( @@ -636,7 +651,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im @pytest.mark.vcr() async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me a dict, respond on one line') assert result.output == snapshot('{"foo":"bar","baz":123}') assert result.all_messages() == snapshot( @@ -703,7 +718,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_error(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me an error, pass False as a value, unless the tool tells you otherwise') assert result.output == snapshot( 'I called the tool with the correct parameter, and it returned: "This is not an error."' @@ -817,7 +832,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_none(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Call the none tool and say Hello') assert result.output == snapshot('Hello! How can I assist you today?') assert result.all_messages() == snapshot( @@ -884,7 +899,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_multiple_items(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me multiple items and summarize in one sentence') assert result.output == snapshot( 'The data includes two strings, a dictionary with a key-value pair, and an image of a sliced kiwi.' @@ -973,11 +988,11 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: ) -async def test_client_sampling(): +async def test_client_sampling(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) server.sampling_model = TestModel(custom_output_text='sampling model response') async with server: - result = await server.call_tool('use_sampling', {'foo': 'bar'}) + result = await server.call_tool(run_context, 'use_sampling', {'foo': 'bar'}) assert result == snapshot( { 'meta': None, @@ -989,27 +1004,27 @@ async def test_client_sampling(): ) -async def test_client_sampling_disabled(): +async def test_client_sampling_disabled(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], allow_sampling=False) server.sampling_model = TestModel(custom_output_text='sampling model response') async with server: with pytest.raises(ModelRetry, match='Error executing tool use_sampling: Sampling not supported'): - await server.call_tool('use_sampling', {'foo': 'bar'}) - + await server.call_tool(run_context, 'use_sampling', {'foo': 'bar'}) -async def test_mcp_server_raises_mcp_error(allow_model_requests: None, agent: Agent) -> None: - server = agent._mcp_servers[0] # pyright: ignore[reportPrivateUsage] +async def test_mcp_server_raises_mcp_error( + allow_model_requests: None, mcp_server: MCPServerStdio, agent: Agent, run_context: RunContext[int] +) -> None: mcp_error = McpError(error=ErrorData(code=400, message='Test MCP error conversion')) - async with agent.run_mcp_servers(): + async with agent: with patch.object( - server._client, # pyright: ignore[reportPrivateUsage] + mcp_server._client, # pyright: ignore[reportPrivateUsage] 'send_request', new=AsyncMock(side_effect=mcp_error), ): with pytest.raises(ModelRetry, match='Test MCP error conversion'): - await server.call_tool('test_tool', {}) + await mcp_server.call_tool(run_context, 'test_tool', {}) def test_map_from_mcp_params_model_request(): diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 675e69000..cc5a0c4b3 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -5,6 +5,7 @@ import re from collections.abc import AsyncIterator from copy import deepcopy +from dataclasses import replace from datetime import timezone from typing import Any, Union @@ -12,14 +13,16 @@ from inline_snapshot import snapshot from pydantic import BaseModel -from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages +from pydantic_ai import Agent, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages from pydantic_ai.agent import AgentRun from pydantic_ai.messages import ( + FinalResultEvent, FunctionToolCallEvent, FunctionToolResultEvent, ModelMessage, ModelRequest, ModelResponse, + PartStartEvent, RetryPromptPart, TextPart, ToolCallPart, @@ -28,8 +31,9 @@ ) from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.output import PromptedOutput, TextOutput +from pydantic_ai.output import DeferredToolCalls, PromptedOutput, TextOutput from pydantic_ai.result import AgentStream, FinalResult, Usage +from pydantic_ai.tools import ToolDefinition from pydantic_graph import End from .conftest import IsInt, IsNow, IsStr @@ -272,7 +276,7 @@ async def text_stream(_messages: list[ModelMessage], _: AgentInfo) -> AsyncItera agent = Agent(FunctionModel(stream_function=text_stream), output_type=tuple[str, str]) - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for output validation'): async with agent.run_stream(''): pass @@ -407,7 +411,7 @@ async def ret_a(x: str) -> str: # pragma: no cover return x with capture_run_messages() as messages: - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(0\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(0\) for output validation'): async with agent.run_stream('hello'): pass @@ -613,18 +617,18 @@ def another_tool(y: int) -> int: timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), ), - RetryPromptPart( - tool_name='unknown_tool', - content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", - timestamp=IsNow(tz=timezone.utc), - tool_call_id=IsStr(), - ), ToolReturnPart( tool_name='regular_tool', content=42, timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() ), ToolReturnPart( tool_name='another_tool', content=2, timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() ), + RetryPromptPart( + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", + tool_name='unknown_tool', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ), ] ), ] @@ -712,15 +716,15 @@ def another_tool(y: int) -> int: # pragma: no cover ModelRequest( parts=[ ToolReturnPart( - tool_name='regular_tool', - content='Tool not executed - a final result was already processed.', + tool_name='final_result', + content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), part_kind='tool-return', ), ToolReturnPart( - tool_name='final_result', - content='Final result processed.', + tool_name='regular_tool', + content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), part_kind='tool-return', @@ -733,10 +737,7 @@ def another_tool(y: int) -> int: # pragma: no cover part_kind='tool-return', ), RetryPromptPart( - content='Unknown tool name: ' - "'unknown_tool'. Available tools: " - 'regular_tool, another_tool, ' - 'final_result', + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), @@ -975,6 +976,13 @@ def known_tool(x: int) -> int: assert event_parts == snapshot( [ + FunctionToolCallEvent( + part=ToolCallPart( + tool_name='known_tool', + args={'x': 5}, + tool_call_id=IsStr(), + ) + ), FunctionToolCallEvent( part=ToolCallPart( tool_name='unknown_tool', @@ -984,14 +992,11 @@ def known_tool(x: int) -> int: ), FunctionToolResultEvent( result=RetryPromptPart( - content="Unknown tool name: 'unknown_tool'. Available tools: known_tool", + content="Unknown tool name: 'unknown_tool'. Available tools: 'known_tool'", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), - ) - ), - FunctionToolCallEvent( - part=ToolCallPart(tool_name='known_tool', args={'x': 5}, tool_call_id=IsStr()), + ), ), FunctionToolResultEvent( result=ToolReturnPart( @@ -999,13 +1004,6 @@ def known_tool(x: int) -> int: content=10, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), - ) - ), - FunctionToolCallEvent( - part=ToolCallPart( - tool_name='unknown_tool', - args={'arg': 'value'}, - tool_call_id=IsStr(), ), ), ] @@ -1027,15 +1025,15 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf agent = Agent(FunctionModel(call_final_result_with_bad_data), output_type=OutputType) - event_parts: list[Any] = [] + events: list[Any] = [] async with agent.iter('test') as agent_run: async for node in agent_run: if Agent.is_call_tools_node(node): async with node.stream(agent_run.ctx) as event_stream: async for event in event_stream: - event_parts.append(event) + events.append(event) - assert event_parts == snapshot( + assert events == snapshot( [ FunctionToolCallEvent( part=ToolCallPart( @@ -1045,9 +1043,16 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf ), ), FunctionToolResultEvent( - result=ToolReturnPart( + result=RetryPromptPart( + content=[ + { + 'type': 'missing', + 'loc': ('value',), + 'msg': 'Field required', + 'input': {'bad_value': 'invalid'}, + } + ], tool_name='final_result', - content='Output tool not used - result failed validation.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ) @@ -1094,3 +1099,95 @@ def test_function_tool_event_tool_call_id_properties(): # The event should expose the same `tool_call_id` as the result part assert result_event.tool_call_id == return_part.tool_call_id == 'return_id_456' + + +async def test_deferred_tool(): + agent = Agent(TestModel(), output_type=[str, DeferredToolCalls]) + + async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition: + return replace(tool_def, kind='deferred') + + @agent.tool_plain(prepare=prepare_tool) + def my_tool(x: int) -> int: + return x + 1 + + async with agent.run_stream('Hello') as result: + assert not result.is_complete + output = await result.get_output() + assert output == snapshot( + DeferredToolCalls( + tool_calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], + tool_defs={ + 'my_tool': ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + 'type': 'object', + }, + kind='deferred', + ) + }, + ) + ) + assert result.is_complete + + +async def test_deferred_tool_iter(): + agent = Agent(TestModel(), output_type=[str, DeferredToolCalls]) + + async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition: + return replace(tool_def, kind='deferred') + + @agent.tool_plain(prepare=prepare_tool) + def my_tool(x: int) -> int: + return x + 1 + + outputs: list[str | DeferredToolCalls] = [] + events: list[Any] = [] + + async with agent.iter('test') as run: + async for node in run: + if agent.is_model_request_node(node): + async with node.stream(run.ctx) as stream: + async for event in stream: + events.append(event) + async for output in stream.stream_output(debounce_by=None): + outputs.append(output) + if agent.is_call_tools_node(node): + async with node.stream(run.ctx) as stream: + async for event in stream: + events.append(event) + + assert outputs == snapshot( + [ + DeferredToolCalls( + tool_calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], + tool_defs={ + 'my_tool': ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + 'type': 'object', + }, + kind='deferred', + ) + }, + ) + ] + ) + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr()), + ), + FinalResultEvent(tool_name=None, tool_call_id=None), + FunctionToolCallEvent(part=ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())), + ] + ) diff --git a/tests/test_tools.py b/tests/test_tools.py index c316a9cb7..fe582f717 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -12,11 +12,15 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, RunContext, Tool, UserError +from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, ToolReturnPart from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.output import ToolOutput +from pydantic_ai.output import DeferredToolCalls, ToolOutput from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.deferred import DeferredToolset + +from .conftest import IsStr def test_tool_no_ctx(): @@ -105,6 +109,7 @@ def test_docstring_google(docstring_format: Literal['google', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) keys = list(json_schema.keys()) @@ -141,6 +146,7 @@ def test_docstring_sphinx(docstring_format: Literal['sphinx', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -180,6 +186,7 @@ def test_docstring_numpy(docstring_format: Literal['numpy', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -219,6 +226,7 @@ def my_tool(x: int) -> str: # pragma: no cover }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -256,6 +264,7 @@ def my_tool(x: int) -> str: # pragma: no cover }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -299,6 +308,7 @@ def my_tool(x: int) -> str: # pragma: no cover }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -330,6 +340,7 @@ def test_only_returns_type(): 'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -352,6 +363,7 @@ def test_docstring_unknown(): 'parameters_json_schema': {'properties': {}, 'type': 'object'}, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -392,6 +404,7 @@ def test_docstring_google_no_body(docstring_format: Literal['google', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -425,6 +438,7 @@ def takes_just_model(model: Foo) -> str: }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -467,6 +481,7 @@ def takes_just_model(model: Foo, z: int) -> str: }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -486,15 +501,15 @@ def plain_tool(x: int) -> int: result = agent.run_sync('foobar') assert result.output == snapshot('{"plain_tool":1}') assert call_args == snapshot([0]) - assert agent._function_tools['plain_tool'].takes_ctx is False - assert agent._function_tools['plain_tool'].max_retries == 7 + assert agent._function_toolset.tools['plain_tool'].takes_ctx is False + assert agent._function_toolset.tools['plain_tool'].max_retries == 7 agent_infer = Agent('test', tools=[plain_tool], retries=7) result = agent_infer.run_sync('foobar') assert result.output == snapshot('{"plain_tool":1}') assert call_args == snapshot([0, 0]) - assert agent_infer._function_tools['plain_tool'].takes_ctx is False - assert agent_infer._function_tools['plain_tool'].max_retries == 7 + assert agent_infer._function_toolset.tools['plain_tool'].takes_ctx is False + assert agent_infer._function_toolset.tools['plain_tool'].max_retries == 7 def ctx_tool(ctx: RunContext[int], x: int) -> int: @@ -506,13 +521,13 @@ def test_init_tool_ctx(): agent = Agent('test', tools=[Tool(ctx_tool, takes_ctx=True, max_retries=3)], deps_type=int, retries=7) result = agent.run_sync('foobar', deps=5) assert result.output == snapshot('{"ctx_tool":5}') - assert agent._function_tools['ctx_tool'].takes_ctx is True - assert agent._function_tools['ctx_tool'].max_retries == 3 + assert agent._function_toolset.tools['ctx_tool'].takes_ctx is True + assert agent._function_toolset.tools['ctx_tool'].max_retries == 3 agent_infer = Agent('test', tools=[ctx_tool], deps_type=int) result = agent_infer.run_sync('foobar', deps=6) assert result.output == snapshot('{"ctx_tool":6}') - assert agent_infer._function_tools['ctx_tool'].takes_ctx is True + assert agent_infer._function_toolset.tools['ctx_tool'].takes_ctx is True def test_repeat_tool_by_rename(): @@ -562,7 +577,7 @@ def foo(x: int, y: str) -> str: # pragma: no cover def bar(x: int, y: str) -> str: # pragma: no cover return f'{x} {y}' - with pytest.raises(UserError, match=r"Tool name conflicts with existing tool: 'bar'."): + with pytest.raises(UserError, match="Tool name conflicts with previously renamed tool: 'bar'."): agent.run_sync('') @@ -572,7 +587,10 @@ def test_tool_return_conflict(): # this is also okay Agent('test', tools=[ctx_tool], deps_type=int, output_type=int) # this raises an error - with pytest.raises(UserError, match="Tool name conflicts with output tool name: 'ctx_tool'"): + with pytest.raises( + UserError, + match="Function toolset defines a tool whose name conflicts with existing tool from Output toolset: 'ctx_tool'. Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.", + ): Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')) @@ -803,6 +821,7 @@ def test_suppress_griffe_logging(caplog: LogCaptureFixture): 'outer_typed_dict_key': None, 'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, 'strict': None, + 'kind': 'function', } ) @@ -872,6 +891,7 @@ def my_tool_plain(*, a: int = 1, b: int) -> int: 'type': 'object', }, 'strict': None, + 'kind': 'function', }, { 'description': '', @@ -884,6 +904,7 @@ def my_tool_plain(*, a: int = 1, b: int) -> int: 'type': 'object', }, 'strict': None, + 'kind': 'function', }, ] ) @@ -968,6 +989,7 @@ def my_tool(x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = 'type': 'object', }, 'strict': None, + 'kind': 'function', }, { 'description': '', @@ -978,6 +1000,7 @@ def my_tool(x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = 'type': 'object', }, 'strict': None, + 'kind': 'function', }, ] ) @@ -1013,6 +1036,7 @@ def get_score(data: Data) -> int: ... # pragma: no branch }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -1044,7 +1068,7 @@ def foobar(ctx: RunContext[int], x: int, y: str) -> str: with agent.override(model=FunctionModel(get_json_schema)): result = agent.run_sync('', deps=21) json_schema = json.loads(result.output) - assert agent._function_tools['foobar'].strict is None + assert agent._function_toolset.tools['foobar'].strict is None assert json_schema['strict'] is True result = agent.run_sync('', deps=1) @@ -1071,8 +1095,8 @@ def function(*args: Any, **kwargs: Any) -> str: agent = Agent('test', tools=[pydantic_tool], retries=0) result = agent.run_sync('foobar') assert result.output == snapshot('{"foobar":"I like being called like this"}') - assert agent._function_tools['foobar'].takes_ctx is False - assert agent._function_tools['foobar'].max_retries == 0 + assert agent._function_toolset.tools['foobar'].takes_ctx is False + assert agent._function_toolset.tools['foobar'].max_retries == 0 def test_function_tool_inconsistent_with_schema(): @@ -1118,5 +1142,101 @@ async def function(*args: Any, **kwargs: Any) -> str: agent = Agent('test', tools=[pydantic_tool], retries=0) result = agent.run_sync('foobar') assert result.output == snapshot('{"foobar":"I like being called like this"}') - assert agent._function_tools['foobar'].takes_ctx is False - assert agent._function_tools['foobar'].max_retries == 0 + assert agent._function_toolset.tools['foobar'].takes_ctx is False + assert agent._function_toolset.tools['foobar'].max_retries == 0 + + +def test_tool_retries(): + prepare_tools_retries: list[int] = [] + prepare_retries: list[int] = [] + call_retries: list[int] = [] + + async def prepare_tool_defs( + ctx: RunContext[None], tool_defs: list[ToolDefinition] + ) -> Union[list[ToolDefinition], None]: + nonlocal prepare_tools_retries + retry = ctx.retries.get('infinite_retry_tool', 0) + prepare_tools_retries.append(retry) + return tool_defs + + agent = Agent(TestModel(), retries=3, prepare_tools=prepare_tool_defs) + + async def prepare_tool_def(ctx: RunContext[None], tool_def: ToolDefinition) -> Union[ToolDefinition, None]: + nonlocal prepare_retries + prepare_retries.append(ctx.retry) + return tool_def + + @agent.tool(retries=5, prepare=prepare_tool_def) + def infinite_retry_tool(ctx: RunContext[None]) -> int: + nonlocal call_retries + call_retries.append(ctx.retry) + raise ModelRetry('Please try again.') + + with pytest.raises(UnexpectedModelBehavior, match="Tool 'infinite_retry_tool' exceeded max retries count of 5"): + agent.run_sync('Begin infinite retry loop!') + + # There are extra 0s here because the toolset is prepared once ahead of the graph run, before the user prompt part is added in. + assert prepare_tools_retries == [0, 0, 1, 2, 3, 4, 5] + assert prepare_retries == [0, 0, 1, 2, 3, 4, 5] + assert call_retries == [0, 1, 2, 3, 4, 5] + + +def test_deferred_tool(): + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent(TestModel(), output_type=[str, DeferredToolCalls], toolsets=[deferred_toolset]) + + result = agent.run_sync('Hello') + assert result.output == snapshot( + DeferredToolCalls( + tool_calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], + tool_defs={ + 'my_tool': ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={ + 'type': 'object', + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + }, + kind='deferred', + ) + }, + ) + ) + + +def test_deferred_tool_with_output_type(): + class MyModel(BaseModel): + foo: str + + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent(TestModel(call_tools=[]), output_type=[MyModel, DeferredToolCalls], toolsets=[deferred_toolset]) + + result = agent.run_sync('Hello') + assert result.output == snapshot(MyModel(foo='a')) + + +def test_output_type_deferred_tool_calls_by_itself(): + with pytest.raises(UserError, match='At least one output type must be provided other than `DeferredToolCalls`.'): + Agent(TestModel(), output_type=DeferredToolCalls) + + +def test_output_type_empty(): + with pytest.raises(UserError, match='At least one output type must be provided.'): + Agent(TestModel(), output_type=[]) diff --git a/tests/test_toolset.py b/tests/test_toolset.py new file mode 100644 index 000000000..4122df05b --- /dev/null +++ b/tests/test_toolset.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import TypeVar + +import pytest +from inline_snapshot import snapshot + +from pydantic_ai._run_context import RunContext +from pydantic_ai.models.test import TestModel +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.usage import Usage + +pytestmark = pytest.mark.anyio + +T = TypeVar('T') + + +def build_run_context(deps: T) -> RunContext[T]: + return RunContext( + deps=deps, + model=TestModel(), + usage=Usage(), + sampling_model=TestModel(), + prompt=None, + messages=[], + run_step=0, + ) + + +async def test_function_toolset_prepare_for_run(): + @dataclass + class PrefixDeps: + prefix: str | None = None + + context = build_run_context(PrefixDeps()) + toolset = FunctionToolset[PrefixDeps]() + + async def prepare_add_prefix(ctx: RunContext[PrefixDeps], tool_def: ToolDefinition) -> ToolDefinition | None: + if ctx.deps.prefix is None: + return tool_def + + return replace(tool_def, name=f'{ctx.deps.prefix}_{tool_def.name}') + + @toolset.tool(prepare=prepare_add_prefix) + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + assert toolset.tool_names == snapshot(['add']) + assert toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='add', + description='Add two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ) + ] + ) + assert await toolset.call_tool(context, 'add', {'a': 1, 'b': 2}) == 3 + + no_prefix_context = build_run_context(PrefixDeps()) + no_prefix_toolset = await toolset.prepare_for_run(no_prefix_context) + assert no_prefix_toolset.tool_names == toolset.tool_names + assert no_prefix_toolset.tool_defs == toolset.tool_defs + assert await no_prefix_toolset.call_tool(no_prefix_context, 'add', {'a': 1, 'b': 2}) == 3 + + foo_context = build_run_context(PrefixDeps(prefix='foo')) + foo_toolset = await toolset.prepare_for_run(foo_context) + assert foo_toolset.tool_names == snapshot(['foo_add']) + assert foo_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='foo_add', + description='Add two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ) + ] + ) + assert await foo_toolset.call_tool(foo_context, 'foo_add', {'a': 1, 'b': 2}) == 3 + + @toolset.tool + def subtract(a: int, b: int) -> int: + """Subtract two numbers""" + return a - b # pragma: lax no cover + + assert foo_toolset.tool_names == snapshot(['foo_add']) + + bar_context = build_run_context(PrefixDeps(prefix='bar')) + bar_toolset = await toolset.prepare_for_run(bar_context) + assert bar_toolset.tool_names == snapshot(['bar_add', 'subtract']) + assert bar_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='bar_add', + description='Add two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='subtract', + description='Subtract two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ] + ) + assert await bar_toolset.call_tool(bar_context, 'bar_add', {'a': 1, 'b': 2}) == 3 + + bar_foo_toolset = await foo_toolset.prepare_for_run(bar_context) + assert bar_foo_toolset == bar_toolset