-
Notifications
You must be signed in to change notification settings - Fork 1k
feat: Add output function tracing #2191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+852
−15
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
0ae3c18
Add span generation to ObjectOutputProcessor
bitnahian 539cfd3
Add working version with non-ToolCallPart call
bitnahian 0bb1ba3
Add working version
bitnahian 16b9f4d
Working tests
bitnahian 491c856
Fix tool_response serialisation in tracing
bitnahian e2b96a2
Add tests for TraceContext
bitnahian 2d736b1
Add tracing for TextOutputSchema
bitnahian 61b17de
Raise errors in unreachable code
bitnahian ef8c0af
simplify
alexmojaki eb7be2c
dedupe function schema call to one helper function
bitnahian 16e5aad
Merge branch 'main' into 2108-bitnahian
bitnahian c3ee3ab
Add list snapshot for retry test
bitnahian b210ee8
Make trace context non-nullable
bitnahian ae7cd7e
Update tests/test_logfire.py
bitnahian 2fe4642
Fix PR comments
bitnahian a3a25e8
Fix more with_calls
bitnahian 6fa702d
Fix more PR comments
bitnahian b2bb60e
Move statements outside try block
bitnahian File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,20 @@ | ||
from __future__ import annotations as _annotations | ||
|
||
import dataclasses | ||
import inspect | ||
import json | ||
from abc import ABC, abstractmethod | ||
from collections.abc import Awaitable, Iterable, Iterator, Sequence | ||
from dataclasses import dataclass, field | ||
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload | ||
|
||
from opentelemetry.trace import Tracer | ||
from pydantic import TypeAdapter, ValidationError | ||
from pydantic_core import SchemaValidator | ||
from typing_extensions import TypedDict, TypeVar, assert_never | ||
|
||
from pydantic_graph.nodes import GraphRunContext | ||
|
||
from . import _function_schema, _utils, messages as _messages | ||
from ._run_context import AgentDepsT, RunContext | ||
from .exceptions import ModelRetry, UserError | ||
|
@@ -29,6 +33,8 @@ | |
from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition | ||
|
||
if TYPE_CHECKING: | ||
from pydantic_ai._agent_graph import DepsT, GraphAgentDeps, GraphAgentState | ||
|
||
from .profiles import ModelProfile | ||
|
||
T = TypeVar('T') | ||
|
@@ -66,6 +72,71 @@ | |
DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation' | ||
|
||
|
||
@dataclass(frozen=True) | ||
class TraceContext: | ||
"""A context for tracing output processing.""" | ||
|
||
# OpenTelemetry tracer used to open spans around output-function calls. | ||
tracer: Tracer | ||
# When False, tool arguments/response payloads are omitted from span attributes. | ||
include_content: bool | ||
# The tool call currently being processed, if the output came from a tool call. | ||
call: _messages.ToolCallPart | None = None | ||
|
||
def with_call(self, call: _messages.ToolCallPart): | ||
# Frozen dataclass, so bind the call by returning a modified copy. | ||
return dataclasses.replace(self, call=call) | ||
|
||
async def execute_function_with_span( | ||
self, | ||
function_schema: _function_schema.FunctionSchema, | ||
run_context: RunContext[AgentDepsT], | ||
args: dict[str, Any] | Any, | ||
call: _messages.ToolCallPart, | ||
include_tool_call_id: bool = True, | ||
) -> Any: | ||
"""Execute a function call within a traced span, automatically recording the response.""" | ||
# Set up span attributes | ||
attributes = { | ||
'gen_ai.tool.name': call.tool_name, | ||
'logfire.msg': f'running output function: {call.tool_name}', | ||
} | ||
# Synthetic calls (e.g. text output functions) carry no real tool_call_id, so it is skippable. | ||
if include_tool_call_id: | ||
attributes['gen_ai.tool.call.id'] = call.tool_call_id | ||
if self.include_content: | ||
attributes['tool_arguments'] = call.args_as_json_str() | ||
attributes['logfire.json_schema'] = json.dumps( | ||
{ | ||
'type': 'object', | ||
'properties': { | ||
'tool_arguments': {'type': 'object'}, | ||
'tool_response': {'type': 'object'}, | ||
}, | ||
} | ||
) | ||
|
||
# Execute function within span | ||
with self.tracer.start_as_current_span('running output function', attributes=attributes) as span: | ||
output = await function_schema.call(args, run_context) | ||
|
||
# Record response if content inclusion is enabled | ||
if self.include_content and span.is_recording(): | ||
# Imported locally, presumably to avoid a circular import at module load — TODO confirm. | ||
from .models.instrumented import InstrumentedModel | ||
|
||
# Strings are recorded verbatim; any other output is JSON-serialized first. | ||
span.set_attribute( | ||
'tool_response', | ||
output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)), | ||
) | ||
|
||
return output | ||
|
||
|
||
def build_trace_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> TraceContext: | ||
"""Build a `TraceContext` from the current agent graph run context.""" | ||
return TraceContext( | ||
tracer=ctx.deps.tracer, | ||
# Content is only included when instrumentation settings exist and opt in to it. | ||
include_content=( | ||
ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content | ||
), | ||
) | ||
|
||
|
||
class ToolRetryError(Exception): | ||
"""Exception used to signal a `ToolRetry` message should be returned to the LLM.""" | ||
|
||
|
@@ -96,6 +167,7 @@ async def validate( | |
result: The result data after Pydantic validation the message content. | ||
tool_call: The original tool call message, `None` if there was no tool call. | ||
run_context: The current run context. | ||
trace_context: The trace context to use for tracing the output processing. | ||
|
||
Returns: | ||
Result of either the validated result data (ok) or a retry message (Err). | ||
|
@@ -349,6 +421,7 @@ async def process( | |
self, | ||
text: str, | ||
run_context: RunContext[AgentDepsT], | ||
trace_context: TraceContext, | ||
allow_partial: bool = False, | ||
wrap_validation_errors: bool = True, | ||
) -> OutputDataT: | ||
|
@@ -371,6 +444,7 @@ async def process( | |
self, | ||
text: str, | ||
run_context: RunContext[AgentDepsT], | ||
trace_context: TraceContext, | ||
allow_partial: bool = False, | ||
wrap_validation_errors: bool = True, | ||
) -> OutputDataT: | ||
|
@@ -379,6 +453,7 @@ async def process( | |
Args: | ||
text: The output text to validate. | ||
run_context: The current run context. | ||
trace_context: The trace context to use for tracing the output processing. | ||
allow_partial: If true, allow partial validation. | ||
wrap_validation_errors: If true, wrap the validation errors in a retry message. | ||
|
||
|
@@ -389,7 +464,7 @@ async def process( | |
return cast(OutputDataT, text) | ||
|
||
return await self.processor.process( | ||
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors | ||
text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors | ||
) | ||
|
||
|
||
|
@@ -417,6 +492,7 @@ async def process( | |
self, | ||
text: str, | ||
run_context: RunContext[AgentDepsT], | ||
trace_context: TraceContext, | ||
allow_partial: bool = False, | ||
wrap_validation_errors: bool = True, | ||
) -> OutputDataT: | ||
|
@@ -425,14 +501,15 @@ async def process( | |
Args: | ||
text: The output text to validate. | ||
run_context: The current run context. | ||
trace_context: The trace context to use for tracing the output processing. | ||
allow_partial: If true, allow partial validation. | ||
wrap_validation_errors: If true, wrap the validation errors in a retry message. | ||
|
||
Returns: | ||
Either the validated output data (left) or a retry message (right). | ||
""" | ||
return await self.processor.process( | ||
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors | ||
text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors | ||
) | ||
|
||
|
||
|
@@ -468,6 +545,7 @@ async def process( | |
self, | ||
text: str, | ||
run_context: RunContext[AgentDepsT], | ||
trace_context: TraceContext, | ||
allow_partial: bool = False, | ||
wrap_validation_errors: bool = True, | ||
) -> OutputDataT: | ||
|
@@ -476,6 +554,7 @@ async def process( | |
Args: | ||
text: The output text to validate. | ||
run_context: The current run context. | ||
trace_context: The trace context to use for tracing the output processing. | ||
allow_partial: If true, allow partial validation. | ||
wrap_validation_errors: If true, wrap the validation errors in a retry message. | ||
|
||
|
@@ -485,7 +564,7 @@ async def process( | |
text = _utils.strip_markdown_fences(text) | ||
|
||
return await self.processor.process( | ||
text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors | ||
text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors | ||
) | ||
|
||
|
||
|
@@ -568,6 +647,7 @@ async def process( | |
self, | ||
data: str, | ||
run_context: RunContext[AgentDepsT], | ||
trace_context: TraceContext, | ||
allow_partial: bool = False, | ||
wrap_validation_errors: bool = True, | ||
) -> OutputDataT: | ||
|
@@ -637,6 +717,7 @@ async def process( | |
self, | ||
data: str | dict[str, Any] | None, | ||
run_context: RunContext[AgentDepsT], | ||
trace_context: TraceContext, | ||
allow_partial: bool = False, | ||
wrap_validation_errors: bool = True, | ||
) -> OutputDataT: | ||
|
@@ -645,6 +726,7 @@ async def process( | |
Args: | ||
data: The output data to validate. | ||
run_context: The current run context. | ||
trace_context: The trace context to use for tracing the output processing. | ||
allow_partial: If true, allow partial validation. | ||
wrap_validation_errors: If true, wrap the validation errors in a retry message. | ||
|
||
|
@@ -670,8 +752,18 @@ async def process( | |
output = output[k] | ||
|
||
if self._function_schema: | ||
# Wraps the output function call in an OpenTelemetry span. | ||
if trace_context.call: | ||
call = trace_context.call | ||
include_tool_call_id = True | ||
else: | ||
function_name = getattr(self._function_schema.function, '__name__', 'output_function') | ||
call = _messages.ToolCallPart(tool_name=function_name, args=data) | ||
include_tool_call_id = False | ||
try: | ||
output = await self._function_schema.call(output, run_context) | ||
output = await trace_context.execute_function_with_span( | ||
self._function_schema, run_context, output, call, include_tool_call_id | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Possible followup: dedupe this whole try/except
except ModelRetry as r: | ||
if wrap_validation_errors: | ||
m = _messages.RetryPromptPart( | ||
|
@@ -784,11 +876,12 @@ async def process( | |
self, | ||
data: str | dict[str, Any] | None, | ||
run_context: RunContext[AgentDepsT], | ||
trace_context: TraceContext, | ||
allow_partial: bool = False, | ||
wrap_validation_errors: bool = True, | ||
) -> OutputDataT: | ||
union_object = await self._union_processor.process( | ||
data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors | ||
data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors | ||
) | ||
|
||
result = union_object.result | ||
|
@@ -804,7 +897,7 @@ async def process( | |
raise | ||
|
||
return await processor.process( | ||
data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors | ||
data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors | ||
) | ||
|
||
|
||
|
@@ -835,13 +928,20 @@ async def process( | |
self, | ||
data: str, | ||
run_context: RunContext[AgentDepsT], | ||
trace_context: TraceContext, | ||
allow_partial: bool = False, | ||
wrap_validation_errors: bool = True, | ||
) -> OutputDataT: | ||
args = {self._str_argument_name: data} | ||
|
||
# Wraps the output function call in an OpenTelemetry span. | ||
# Note: PlainTextOutputProcessor is used for text responses (not tool calls), | ||
# so we don't have tool call attributes like gen_ai.tool.name or gen_ai.tool.call.id | ||
function_name = getattr(self._function_schema.function, '__name__', 'text_output_function') | ||
call = _messages.ToolCallPart(tool_name=function_name, args=args) | ||
try: | ||
output = await self._function_schema.call(args, run_context) | ||
output = await trace_context.execute_function_with_span( | ||
self._function_schema, run_context, args, call, include_tool_call_id=False | ||
) | ||
except ModelRetry as r: | ||
if wrap_validation_errors: | ||
m = _messages.RetryPromptPart( | ||
|
@@ -881,6 +981,7 @@ async def process( | |
self, | ||
tool_call: _messages.ToolCallPart, | ||
run_context: RunContext[AgentDepsT], | ||
trace_context: TraceContext, | ||
allow_partial: bool = False, | ||
wrap_validation_errors: bool = True, | ||
) -> OutputDataT: | ||
|
@@ -889,6 +990,7 @@ async def process( | |
Args: | ||
tool_call: The tool call from the LLM to validate. | ||
run_context: The current run context. | ||
trace_context: The trace context to use for tracing the output processing. | ||
allow_partial: If true, allow partial validation. | ||
wrap_validation_errors: If true, wrap the validation errors in a retry message. | ||
|
||
|
@@ -897,7 +999,11 @@ async def process( | |
""" | ||
try: | ||
output = await self.processor.process( | ||
tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False | ||
tool_call.args, | ||
run_context, | ||
trace_context.with_call(tool_call), | ||
allow_partial=allow_partial, | ||
wrap_validation_errors=False, | ||
) | ||
except ValidationError as e: | ||
if wrap_validation_errors: | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've noticed a preference towards dataclasses in the code base. Just curious as to why this is the preferred choice? Is it primarily because of serialisation semantics?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A dataclass works nicely here and makes more sense than a 'plain' class.
A pydantic BaseModel would be better suited for data intended to be (de)serialized.