Upgrade fasta2a to A2A v0.2.3 and Enable Dependency Injection #2103


Open · wants to merge 15 commits into base: `main`

33 changes: 33 additions & 0 deletions docs/a2a.md
@@ -93,4 +93,37 @@ Since `app` is an ASGI application, it can be used with any ASGI server.
uvicorn agent_to_a2a:app --host 0.0.0.0 --port 8000
```

#### Using Agents with Dependencies

If your agent uses [dependencies](../agents.md#dependencies), you can provide a `deps_factory` function that creates dependencies from the A2A task metadata:

```python {title="agent_with_deps_to_a2a.py"}
from dataclasses import dataclass
from pydantic_ai import Agent, RunContext

@dataclass
class SupportDeps:
customer_id: int

support_agent = Agent(
'openai:gpt-4.1',
deps_type=SupportDeps,
instructions='You are a support agent.',
)

@support_agent.system_prompt
def add_customer_info(ctx: RunContext[SupportDeps]) -> str:
return f'The customer ID is {ctx.deps.customer_id}'

def create_deps(task):
"""Create dependencies from task metadata."""
metadata = task.get('metadata', {})
return SupportDeps(customer_id=metadata.get('customer_id', 0))

# Create A2A app with deps_factory
app = support_agent.to_a2a(deps_factory=create_deps)
```

Now when clients send messages with metadata, the agent will have access to the dependencies created by the `deps_factory` function.
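Because task metadata arrives from untrusted clients, a factory may want to coerce and fall back rather than crash the worker on bad input. A stdlib-only sketch (the int-coercion policy and the plain-dict task are illustrative assumptions, not part of this PR; a real task is a `fasta2a.schema.Task`):

```python
from dataclasses import dataclass


@dataclass
class SupportDeps:
    customer_id: int


def create_deps(task: dict) -> SupportDeps:
    """Build deps from task metadata, falling back to a guest customer (id 0)."""
    metadata = task.get('metadata') or {}
    raw = metadata.get('customer_id', 0)
    try:
        customer_id = int(raw)
    except (TypeError, ValueError):
        customer_id = 0  # reject junk values instead of raising inside the worker
    return SupportDeps(customer_id=customer_id)


print(create_deps({'metadata': {'customer_id': '123'}}))  # SupportDeps(customer_id=123)
print(create_deps({}))  # SupportDeps(customer_id=0)
```

Whether to fall back silently or mark the task as failed is a policy choice; the fallback here just keeps the example minimal.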

Since the goal of `to_a2a` is to be a convenience method, it accepts the same arguments as the [`FastA2A`][fasta2a.FastA2A] constructor.
39 changes: 39 additions & 0 deletions docs/examples/bank-support-a2a.md
@@ -0,0 +1,39 @@
Example showing how to expose the [bank support agent](bank-support.md) as an A2A server with dependency injection.

Demonstrates:

* Converting an existing agent to A2A
* Using `deps_factory` to provide customer context
* Passing metadata through A2A protocol

## Running the Example

With [dependencies installed and environment variables set](./index.md#usage), run:

```bash
# Start the A2A server
uvicorn pydantic_ai_examples.bank_support_a2a:app --reload

# In another terminal, send a request
curl -X POST http://localhost:8000/ \
-H "Content-Type: application/json" \
-d '{
"jsonrpc": "2.0",
"method": "tasks/send",
"params": {
"id": "test-task-1",
"message": {
"role": "user",
"parts": [{"type": "text", "text": "What is my balance?"}]
},
"metadata": {"customer_id": 123}
},
"id": "1"
}'
```
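The same two calls (send, then poll for the result as in the example's docstring) can be prepared from Python. A stdlib-only sketch that builds the JSON-RPC request bodies; the `jsonrpc` helper is hypothetical, and actually POSTing the bodies (e.g. with httpx) requires the server above to be running:

```python
import json


def jsonrpc(method: str, params: dict, req_id: str) -> str:
    """Serialize a JSON-RPC 2.0 request like the curl commands above."""
    return json.dumps({'jsonrpc': '2.0', 'method': method, 'params': params, 'id': req_id})


send_body = jsonrpc(
    'tasks/send',
    {
        'id': 'test-task-1',
        'message': {'role': 'user', 'parts': [{'type': 'text', 'text': 'What is my balance?'}]},
        'metadata': {'customer_id': 123},
    },
    '1',
)
get_body = jsonrpc('tasks/get', {'id': 'test-task-1'}, '2')

# POST each body to http://localhost:8000/ with Content-Type: application/json,
# e.g. httpx.post('http://localhost:8000/', content=send_body,
#                 headers={'Content-Type': 'application/json'})
```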

## Example Code

```python {title="bank_support_a2a.py"}
#! examples/pydantic_ai_examples/bank_support_a2a.py
```
3 changes: 2 additions & 1 deletion examples/pydantic_ai_examples/bank_support.py
@@ -32,7 +32,8 @@ async def customer_balance(cls, *, id: int, include_pending: bool) -> float:
else:
return 100.00
else:
raise ValueError('Customer not found')
return 42
# raise ValueError('Customer not found')


@dataclass
75 changes: 75 additions & 0 deletions examples/pydantic_ai_examples/bank_support_a2a.py
@@ -0,0 +1,75 @@
"""Bank support agent exposed as an A2A server.

Shows how to use deps_factory to provide customer context from task metadata.

Run the server:
python -m pydantic_ai_examples.bank_support_a2a
# or
uvicorn pydantic_ai_examples.bank_support_a2a:app --reload

Test with curl:
curl -X POST http://localhost:8000/ \
-H "Content-Type: application/json" \
-d '{
"jsonrpc": "2.0",
"method": "tasks/send",
"params": {
"id": "test-task-1",
"message": {
"role": "user",
"parts": [{"type": "text", "text": "What is my balance?"}]
},
"metadata": {"customer_id": 123}
},
"id": "1"
}'

Then get the result:
curl -X POST http://localhost:8000/ \
-H "Content-Type: application/json" \
-d '{
"jsonrpc": "2.0",
"method": "tasks/get",
"params": {"id": "test-task-1"},
"id": "2"
}'
"""

from fasta2a.schema import Task

from pydantic_ai_examples.bank_support import (
DatabaseConn,
SupportDependencies,
support_agent,
)


def create_deps(task: Task) -> SupportDependencies:
"""Create dependencies from A2A task metadata.

In a real application, you might:
- Validate the customer_id
- Look up authentication from a session token
- Connect to a real database with connection pooling
"""
metadata = task.get('metadata', {})
customer_id = metadata.get('customer_id', 0)

# In production, you'd validate the customer exists
# and the request is authorized
return SupportDependencies(customer_id=customer_id, db=DatabaseConn())


# Create the A2A application
app = support_agent.to_a2a(
deps_factory=create_deps,
name='Bank Support Agent',
description='AI support agent for banking customers',
)


if __name__ == '__main__':
# For development convenience
import uvicorn

uvicorn.run(app, host='0.0.0.0', port=8000)
34 changes: 30 additions & 4 deletions fasta2a/fasta2a/applications.py
@@ -1,9 +1,11 @@
from __future__ import annotations as _annotations

import json
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager
from typing import Any

from sse_starlette import EventSourceResponse
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
@@ -21,6 +23,9 @@
a2a_request_ta,
a2a_response_ta,
agent_card_ta,
send_message_request_ta,
stream_event_ta,
stream_message_request_ta,
)
from .storage import Storage
from .task_manager import TaskManager
@@ -90,7 +95,7 @@ async def _agent_card_endpoint(self, request: Request) -> Response:
skills=self.skills,
default_input_modes=self.default_input_modes,
default_output_modes=self.default_output_modes,
capabilities=Capabilities(streaming=False, push_notifications=False, state_transition_history=False),
capabilities=Capabilities(streaming=True, push_notifications=False, state_transition_history=False),
authentication=Authentication(schemes=[]),
)
if self.description is not None:
Expand All @@ -105,7 +110,7 @@ async def _agent_run_endpoint(self, request: Request) -> Response:

Although the specification leaves implementation details open, a few decisions here are deliberate.

1. The server will always either send a "submitted" or a "failed" on `tasks/send`.
1. The server will always either send a "submitted" or a "failed" on `message/send`.
Never a "completed" on the first message.
2. There are three possible ends for the task:
2.1. The task was "completed" successfully.
@@ -116,8 +121,29 @@
data = await request.body()
a2a_request = a2a_request_ta.validate_json(data)

if a2a_request['method'] == 'tasks/send':
jsonrpc_response = await self.task_manager.send_task(a2a_request)
if a2a_request['method'] == 'message/send':
# Handle new message/send method
message_request = send_message_request_ta.validate_json(data)
jsonrpc_response = await self.task_manager.send_message(message_request)
elif a2a_request['method'] == 'message/stream':
# Parse the streaming request
stream_request = stream_message_request_ta.validate_json(data)

# Create an async generator wrapper that formats events as JSON-RPC responses
async def sse_generator():
request_id = stream_request.get('id')
async for event in self.task_manager.stream_message(stream_request):
# Serialize event to ensure proper camelCase conversion
event_dict = stream_event_ta.dump_python(event, mode='json', by_alias=True)

# Wrap in JSON-RPC response
jsonrpc_response = {'jsonrpc': '2.0', 'id': request_id, 'result': event_dict}

# Convert to JSON string
yield json.dumps(jsonrpc_response)

# Return SSE response
return EventSourceResponse(sse_generator())
elif a2a_request['method'] == 'tasks/get':
jsonrpc_response = await self.task_manager.get_task(a2a_request)
elif a2a_request['method'] == 'tasks/cancel':
82 changes: 81 additions & 1 deletion fasta2a/fasta2a/broker.py
@@ -7,11 +7,12 @@
from typing import Annotated, Any, Generic, Literal, TypeVar

import anyio
from anyio.streams.memory import MemoryObjectSendStream
from opentelemetry.trace import Span, get_current_span, get_tracer
from pydantic import Discriminator
from typing_extensions import Self, TypedDict

from .schema import TaskIdParams, TaskSendParams
from .schema import StreamEvent, TaskIdParams, TaskSendParams

tracer = get_tracer(__name__)

@@ -51,6 +52,26 @@ def receive_task_operations(self) -> AsyncIterator[TaskOperation]:
between the workers.
"""

@abstractmethod
async def send_stream_event(self, task_id: str, event: StreamEvent) -> None:
"""Send a streaming event from worker to subscribers.

This is used by workers to publish status updates, messages, and artifacts
during task execution. Events are forwarded to all active subscribers of
the given task_id.
"""
raise NotImplementedError('send_stream_event is not implemented yet.')

@abstractmethod
def subscribe_to_stream(self, task_id: str) -> AsyncIterator[StreamEvent]:
"""Subscribe to streaming events for a specific task.

Returns an async iterator that yields events published by workers for the
given task_id. The iterator completes when a TaskStatusUpdateEvent with
final=True is received or the subscription is cancelled.
"""
raise NotImplementedError('subscribe_to_stream is not implemented yet.')


OperationT = TypeVar('OperationT')
ParamsT = TypeVar('ParamsT')
@@ -73,6 +94,12 @@ class _TaskOperation(TypedDict, Generic[OperationT, ParamsT]):
class InMemoryBroker(Broker):
"""A broker that schedules tasks in memory."""

def __init__(self):
# Event streams per task_id for pub/sub
self._event_subscribers: dict[str, list[MemoryObjectSendStream[StreamEvent]]] = {}
# Lock for thread-safe subscriber management
self._subscriber_lock = anyio.Lock()

async def __aenter__(self):
self.aexit_stack = AsyncExitStack()
await self.aexit_stack.__aenter__()
@@ -96,3 +123,56 @@ async def receive_task_operations(self) -> AsyncIterator[TaskOperation]:
"""Receive task operations from the broker."""
async for task_operation in self._read_stream:
yield task_operation

async def send_stream_event(self, task_id: str, event: StreamEvent) -> None:
"""Send a streaming event to all subscribers of a task."""
async with self._subscriber_lock:
subscribers = self._event_subscribers.get(task_id, [])
# Send to all active subscribers, removing any that are closed
active_subscribers: list[MemoryObjectSendStream[StreamEvent]] = []
for send_stream in subscribers:
try:
await send_stream.send(event)
active_subscribers.append(send_stream)
except anyio.ClosedResourceError:
# Subscriber disconnected, remove it
pass

# Update subscriber list with only active ones
if active_subscribers:
self._event_subscribers[task_id] = active_subscribers
elif task_id in self._event_subscribers:
# No active subscribers, clean up
del self._event_subscribers[task_id]

async def subscribe_to_stream(self, task_id: str) -> AsyncIterator[StreamEvent]:
"""Subscribe to events for a specific task."""
# Create a new stream for this subscriber
send_stream, receive_stream = anyio.create_memory_object_stream[StreamEvent](max_buffer_size=100)

# Register the subscriber
async with self._subscriber_lock:
if task_id not in self._event_subscribers:
self._event_subscribers[task_id] = []
self._event_subscribers[task_id].append(send_stream)

try:
# Yield events as they arrive
async with receive_stream:
async for event in receive_stream:
yield event
# Check if this is a final event
if isinstance(event, dict) and event.get('kind') == 'status-update' and event.get('final', False):
break
finally:
# Clean up subscription on exit
async with self._subscriber_lock:
if task_id in self._event_subscribers:
try:
self._event_subscribers[task_id].remove(send_stream)
if not self._event_subscribers[task_id]:
del self._event_subscribers[task_id]
except ValueError:
# Already removed
pass
await send_stream.aclose()