feat: Add output function tracing #2191

Merged
merged 18 commits on Jul 16, 2025
7 changes: 5 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -341,6 +341,7 @@ async def stream(
ctx.deps.output_schema,
ctx.deps.output_validators,
build_run_context(ctx),
+ _output.build_trace_context(ctx),
ctx.deps.usage_limits,
)
yield agent_stream
@@ -529,7 +530,8 @@ async def _handle_tool_calls(
if isinstance(output_schema, _output.ToolOutputSchema):
for call, output_tool in output_schema.find_tool(tool_calls):
try:
- result_data = await output_tool.process(call, run_context)
+ trace_context = _output.build_trace_context(ctx)
+ result_data = await output_tool.process(call, run_context, trace_context)
result_data = await _validate_output(result_data, ctx, call)
except _output.ToolRetryError as e:
# TODO: Should only increment retry stuff once per node execution, not for each tool call
@@ -586,7 +588,8 @@ async def _handle_text_response(
try:
if isinstance(output_schema, _output.TextOutputSchema):
run_context = build_run_context(ctx)
- result_data = await output_schema.process(text, run_context)
+ trace_context = _output.build_trace_context(ctx)
+ result_data = await output_schema.process(text, run_context, trace_context)
else:
m = _messages.RetryPromptPart(
content='Plain text responses are not permitted, please include your response in a tool call',
124 changes: 115 additions & 9 deletions pydantic_ai_slim/pydantic_ai/_output.py
@@ -1,16 +1,20 @@
from __future__ import annotations as _annotations

+ import dataclasses
import inspect
+ import json
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Iterable, Iterator, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload

+ from opentelemetry.trace import Tracer
from pydantic import TypeAdapter, ValidationError
from pydantic_core import SchemaValidator
from typing_extensions import TypedDict, TypeVar, assert_never

+ from pydantic_graph.nodes import GraphRunContext

from . import _function_schema, _utils, messages as _messages
from ._run_context import AgentDepsT, RunContext
from .exceptions import ModelRetry, UserError
@@ -29,6 +33,8 @@
from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition

if TYPE_CHECKING:
+ from pydantic_ai._agent_graph import DepsT, GraphAgentDeps, GraphAgentState

from .profiles import ModelProfile

T = TypeVar('T')
@@ -66,6 +72,71 @@
DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'


@dataclass(frozen=True)
class TraceContext:
"""A context for tracing output processing."""

tracer: Tracer
include_content: bool
call: _messages.ToolCallPart | None = None

Comment on lines +75 to +82
Contributor (author):
I've noticed a preference towards dataclasses in the code base. Just curious as to why this is the preferred choice? Is it primarily because of serialisation semantics?

Contributor:
a dataclass works nicely here and makes more sense than a 'plain' class.

a pydantic BaseModel would be better suited for data intended to be (de)serialized.
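To make the trade-off concrete (an editor's sketch, not from the thread): a frozen dataclass gives cheap immutable copies via `dataclasses.replace`, which is exactly what `with_call` below relies on, while a pydantic `BaseModel` would add validation and (de)serialization machinery that a simple carrier type like `TraceContext` never needs.

```python
from __future__ import annotations

import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class Example:
    include_content: bool
    call: str | None = None  # stand-in for _messages.ToolCallPart


ex = Example(include_content=True)
# dataclasses.replace returns a new frozen instance; the original is untouched
ex2 = dataclasses.replace(ex, call='final_result')
assert ex.call is None and ex2.call == 'final_result'
# the pydantic equivalent would be ex.model_copy(update={'call': ...}),
# which brings validation along for the ride -- unnecessary for this carrier type
```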

def with_call(self, call: _messages.ToolCallPart):
return dataclasses.replace(self, call=call)

async def execute_function_with_span(
self,
function_schema: _function_schema.FunctionSchema,
run_context: RunContext[AgentDepsT],
args: dict[str, Any] | Any,
call: _messages.ToolCallPart,
include_tool_call_id: bool = True,
) -> Any:
"""Execute a function call within a traced span, automatically recording the response."""
# Set up span attributes
attributes = {
'gen_ai.tool.name': call.tool_name,
'logfire.msg': f'running output function: {call.tool_name}',
}
if include_tool_call_id:
attributes['gen_ai.tool.call.id'] = call.tool_call_id
if self.include_content:
attributes['tool_arguments'] = call.args_as_json_str()
attributes['logfire.json_schema'] = json.dumps(
{
'type': 'object',
'properties': {
'tool_arguments': {'type': 'object'},
'tool_response': {'type': 'object'},
},
}
)

# Execute function within span
with self.tracer.start_as_current_span('running output function', attributes=attributes) as span:
output = await function_schema.call(args, run_context)

# Record response if content inclusion is enabled
if self.include_content and span.is_recording():
from .models.instrumented import InstrumentedModel

span.set_attribute(
'tool_response',
output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)),
)

return output


def build_trace_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> TraceContext:
"""Build a `TraceContext` from the current agent graph run context."""
return TraceContext(
tracer=ctx.deps.tracer,
include_content=(
ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content
),
)


class ToolRetryError(Exception):
"""Exception used to signal a `ToolRetry` message should be returned to the LLM."""

@@ -96,6 +167,7 @@ async def validate(
result: The result data after Pydantic validation of the message content.
tool_call: The original tool call message, `None` if there was no tool call.
run_context: The current run context.
+ trace_context: The trace context to use for tracing the output processing.

Returns:
Result of either the validated result data (ok) or a retry message (Err).
@@ -349,6 +421,7 @@ async def process(
self,
text: str,
run_context: RunContext[AgentDepsT],
+ trace_context: TraceContext,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
@@ -371,6 +444,7 @@ async def process(
self,
text: str,
run_context: RunContext[AgentDepsT],
+ trace_context: TraceContext,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
@@ -379,6 +453,7 @@ async def process(
Args:
text: The output text to validate.
run_context: The current run context.
+ trace_context: The trace context to use for tracing the output processing.
allow_partial: If true, allow partial validation.
wrap_validation_errors: If true, wrap the validation errors in a retry message.

@@ -389,7 +464,7 @@ async def process(
return cast(OutputDataT, text)

return await self.processor.process(
- text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+ text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
)


@@ -417,6 +492,7 @@ async def process(
self,
text: str,
run_context: RunContext[AgentDepsT],
+ trace_context: TraceContext,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
@@ -425,14 +501,15 @@ async def process(
Args:
text: The output text to validate.
run_context: The current run context.
+ trace_context: The trace context to use for tracing the output processing.
allow_partial: If true, allow partial validation.
wrap_validation_errors: If true, wrap the validation errors in a retry message.

Returns:
Either the validated output data (left) or a retry message (right).
"""
return await self.processor.process(
- text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+ text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
)


@@ -468,6 +545,7 @@ async def process(
self,
text: str,
run_context: RunContext[AgentDepsT],
+ trace_context: TraceContext,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
@@ -476,6 +554,7 @@ async def process(
Args:
text: The output text to validate.
run_context: The current run context.
+ trace_context: The trace context to use for tracing the output processing.
allow_partial: If true, allow partial validation.
wrap_validation_errors: If true, wrap the validation errors in a retry message.

@@ -485,7 +564,7 @@ async def process(
text = _utils.strip_markdown_fences(text)

return await self.processor.process(
- text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+ text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
)


@@ -568,6 +647,7 @@ async def process(
self,
data: str,
run_context: RunContext[AgentDepsT],
+ trace_context: TraceContext,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
@@ -637,6 +717,7 @@ async def process(
self,
data: str | dict[str, Any] | None,
run_context: RunContext[AgentDepsT],
+ trace_context: TraceContext,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
@@ -645,6 +726,7 @@ async def process(
Args:
data: The output data to validate.
run_context: The current run context.
+ trace_context: The trace context to use for tracing the output processing.
allow_partial: If true, allow partial validation.
wrap_validation_errors: If true, wrap the validation errors in a retry message.

@@ -670,8 +752,18 @@ async def process(
output = output[k]

if self._function_schema:
+ # Wraps the output function call in an OpenTelemetry span.
+ if trace_context.call:
+     call = trace_context.call
+     include_tool_call_id = True
+ else:
+     function_name = getattr(self._function_schema.function, '__name__', 'output_function')
+     call = _messages.ToolCallPart(tool_name=function_name, args=data)
+     include_tool_call_id = False
try:
- output = await self._function_schema.call(output, run_context)
+ output = await trace_context.execute_function_with_span(
Contributor:

Possible followup: dedupe this whole try/except
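A sketch of what that dedupe could look like (hypothetical helper, name and signature invented here, not part of this PR):

```python
async def _run_output_function(
    trace_context: TraceContext,
    function_schema: _function_schema.FunctionSchema,
    run_context: RunContext[Any],
    args: Any,
    call: _messages.ToolCallPart,
    *,
    include_tool_call_id: bool = True,
    wrap_validation_errors: bool = True,
) -> Any:
    """Shared wrapper so call sites stop repeating the span + ModelRetry handling."""
    try:
        return await trace_context.execute_function_with_span(
            function_schema, run_context, args, call, include_tool_call_id
        )
    except ModelRetry as r:
        if wrap_validation_errors:
            m = _messages.RetryPromptPart(content=r.message, tool_name=call.tool_name)
            raise ToolRetryError(m) from r
        raise
```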

+     self._function_schema, run_context, output, call, include_tool_call_id
+ )
except ModelRetry as r:
if wrap_validation_errors:
m = _messages.RetryPromptPart(
@@ -784,11 +876,12 @@ async def process(
self,
data: str | dict[str, Any] | None,
run_context: RunContext[AgentDepsT],
+ trace_context: TraceContext,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
union_object = await self._union_processor.process(
- data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+ data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
)

result = union_object.result
@@ -804,7 +897,7 @@
raise

return await processor.process(
- data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
+ data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
)


@@ -835,13 +928,20 @@ async def process(
self,
data: str,
run_context: RunContext[AgentDepsT],
+ trace_context: TraceContext,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
args = {self._str_argument_name: data}

+ # Wraps the output function call in an OpenTelemetry span.
+ # Note: PlainTextOutputProcessor is used for text responses (not tool calls),
+ # so we don't have tool call attributes like gen_ai.tool.name or gen_ai.tool.call.id
+ function_name = getattr(self._function_schema.function, '__name__', 'text_output_function')
+ call = _messages.ToolCallPart(tool_name=function_name, args=args)
try:
- output = await self._function_schema.call(args, run_context)
+ output = await trace_context.execute_function_with_span(
+     self._function_schema, run_context, args, call, include_tool_call_id=False
+ )
except ModelRetry as r:
if wrap_validation_errors:
m = _messages.RetryPromptPart(
@@ -881,6 +981,7 @@ async def process(
self,
tool_call: _messages.ToolCallPart,
run_context: RunContext[AgentDepsT],
+ trace_context: TraceContext,
allow_partial: bool = False,
wrap_validation_errors: bool = True,
) -> OutputDataT:
@@ -889,6 +990,7 @@ async def process(
Args:
tool_call: The tool call from the LLM to validate.
run_context: The current run context.
+ trace_context: The trace context to use for tracing the output processing.
allow_partial: If true, allow partial validation.
wrap_validation_errors: If true, wrap the validation errors in a retry message.

@@ -897,7 +999,11 @@
"""
try:
output = await self.processor.process(
- tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False
+ tool_call.args,
+ run_context,
+ trace_context.with_call(tool_call),
+ allow_partial=allow_partial,
+ wrap_validation_errors=False,
)
except ValidationError as e:
if wrap_validation_errors:
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -1089,6 +1089,7 @@ async def on_complete() -> None:
streamed_response,
graph_ctx.deps.output_schema,
_agent_graph.build_run_context(graph_ctx),
+ _output.build_trace_context(graph_ctx),
graph_ctx.deps.output_validators,
final_result_details.tool_name,
on_complete,
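For readers landing on this diff cold, a minimal usage sketch of what the new spans cover (an editor's illustration, assuming the public `Agent` API at the time of this PR; `instrument=True` is what populates the `tracer` and `instrumentation_settings` that `build_trace_context` reads):

```python
from pydantic_ai import Agent


def to_celsius(fahrenheit: float) -> float:
    """An output function: the model's answer is post-processed through this call."""
    return (fahrenheit - 32) * 5 / 9


# With this PR, the output function call is wrapped in a 'running output function'
# span carrying gen_ai.tool.name, plus tool_arguments/tool_response when the
# instrumentation settings enable include_content.
agent = Agent('openai:gpt-4o', output_type=to_celsius, instrument=True)
result = agent.run_sync('What is the boiling point of water in Fahrenheit?')
print(result.output)  # 100.0
```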