Skip to content

fix: Complete the TODO items in samples/python/agents/google_adk/task_manager.py #379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions samples/python/agents/crewai/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from typing import AsyncIterable
from agent import ImageGenerationAgent
from common.server.task_manager import InMemoryTaskManager
from common.server.base_task_manager import BaseAgentTaskManager
from common.server import utils
from common.types import (
Artifact,
Expand All @@ -24,22 +24,15 @@
logger = logging.getLogger(__name__)


class AgentTaskManager(InMemoryTaskManager):
class AgentTaskManager(BaseAgentTaskManager):
"""Agent Task Manager, handles task routing and response packing."""

def __init__(self, agent: ImageGenerationAgent):
super().__init__()
self.agent = agent

async def _stream_generator(
self, request: SendTaskRequest
) -> AsyncIterable[SendTaskResponse]:
raise NotImplementedError("Not implemented")
super().__init__(agent)

async def on_send_task(
self, request: SendTaskRequest
) -> SendTaskResponse | AsyncIterable[SendTaskResponse]:
## only support text output at the moment
) -> SendTaskResponse:
if not utils.are_modalities_compatible(
request.params.acceptedOutputModes,
ImageGenerationAgent.SUPPORTED_CONTENT_TYPES,
Expand All @@ -53,7 +46,6 @@ async def on_send_task(

task_send_params: TaskSendParams = request.params
await self.upsert_task(task_send_params)

return await self._invoke(request)

async def on_send_task_subscribe(
Expand All @@ -64,6 +56,8 @@ async def on_send_task_subscribe(
return error

await self.upsert_task(request.params)
# crewai는 streaming 미구현
raise NotImplementedError("Not implemented")

async def _update_store(
self, task_id: str, status: TaskStatus, artifacts: list[Artifact]
Expand Down
46 changes: 25 additions & 21 deletions samples/python/agents/google_adk/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
SendTaskStreamingRequest,
SendTaskStreamingResponse,
)
from common.server.task_manager import InMemoryTaskManager
from common.server.base_task_manager import BaseAgentTaskManager
from google.genai import types
import common.server.utils as utils
from typing import Union
import logging
logger = logging.getLogger(__name__)

# TODO: Move this class (or these classes) to a common directory
class AgentWithTaskManager(ABC):

@abstractmethod
Expand Down Expand Up @@ -93,19 +92,34 @@ async def stream(self, query, session_id) -> AsyncIterable[Dict[str, Any]]:
"updates": self.get_processing_message(),
}

class AgentTaskManager(InMemoryTaskManager):
class AgentTaskManager(BaseAgentTaskManager):

def __init__(self, agent: AgentWithTaskManager):
super().__init__()
self.agent = agent
super().__init__(agent)

async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
error = self._validate_request(request)
if error:
return error
await self.upsert_task(request.params)
return await self._invoke(request)

async def on_send_task_subscribe(
self, request: SendTaskStreamingRequest
) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
error = self._validate_request(request)
if error:
return error
await self.upsert_task(request.params)
return self._stream_generator(request)

async def _stream_generator(
self, request: SendTaskStreamingRequest
) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
task_send_params: TaskSendParams = request.params
query = self._get_user_query(task_send_params)
try:
async for item in self.agent.stream(query, task_send_params.sessionId):
async for item in self.stream(query, task_send_params.sessionId):
is_task_complete = item["is_task_complete"]
artifacts = None
if not is_task_complete:
Expand Down Expand Up @@ -163,6 +177,7 @@ async def _stream_generator(
message="An error occurred while streaming the response"
),
)

def _validate_request(
self, request: Union[SendTaskRequest, SendTaskStreamingRequest]
) -> None:
Expand All @@ -176,20 +191,7 @@ def _validate_request(
self.agent.SUPPORTED_CONTENT_TYPES,
)
return utils.new_incompatible_types_error(request.id)
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
error = self._validate_request(request)
if error:
return error
await self.upsert_task(request.params)
return await self._invoke(request)
async def on_send_task_subscribe(
self, request: SendTaskStreamingRequest
) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
error = self._validate_request(request)
if error:
return error
await self.upsert_task(request.params)
return self._stream_generator(request)

async def _update_store(
self, task_id: str, status: TaskStatus, artifacts: list[Artifact]
) -> Task:
Expand All @@ -207,11 +209,12 @@ async def _update_store(
task.artifacts = []
task.artifacts.extend(artifacts)
return task

async def _invoke(self, request: SendTaskRequest) -> SendTaskResponse:
task_send_params: TaskSendParams = request.params
query = self._get_user_query(task_send_params)
try:
result = self.agent.invoke(query, task_send_params.sessionId)
result = self.invoke(query, task_send_params.sessionId)
except Exception as e:
logger.error(f"Error invoking agent: {e}")
raise ValueError(f"Error invoking agent: {e}")
Expand All @@ -225,6 +228,7 @@ async def _invoke(self, request: SendTaskRequest) -> SendTaskResponse:
[Artifact(parts=parts)],
)
return SendTaskResponse(id=request.id, result=task)

def _get_user_query(self, task_send_params: TaskSendParams) -> str:
part = task_send_params.message.parts[0]
if not isinstance(part, TextPart):
Expand Down
8 changes: 3 additions & 5 deletions samples/python/agents/marvin/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import common.server.utils as utils
from agents.marvin.agent import ExtractorAgent
from common.server.task_manager import InMemoryTaskManager
from common.server.base_task_manager import BaseAgentTaskManager
from common.types import (
Artifact,
DataPart,
Expand All @@ -33,15 +33,13 @@
logger = logging.getLogger(__name__)


class AgentTaskManager(InMemoryTaskManager):
class AgentTaskManager(BaseAgentTaskManager):
def __init__(
self,
agent: ExtractorAgent,
notification_sender_auth: PushNotificationSenderAuth,
):
super().__init__()
self.agent = agent
self.notification_sender_auth = notification_sender_auth
super().__init__(agent, notification_sender_auth)

def _parse_agent_outcome(
self, agent_outcome: dict[str, Any]
Expand Down
8 changes: 3 additions & 5 deletions samples/python/agents/semantickernel/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from typing import AsyncIterable

from common.server.task_manager import InMemoryTaskManager
from common.server.base_task_manager import BaseAgentTaskManager
from common.types import (
Artifact,
InternalError,
Expand All @@ -25,14 +25,12 @@
logger = logging.getLogger(__name__)


class TaskManager(InMemoryTaskManager):
class TaskManager(BaseAgentTaskManager):
"""A TaskManager used for the Semantic Kernel Agent sample."""

def __init__(self, notification_sender_auth: PushNotificationSenderAuth):
"""Initialize the TaskManager with a notification sender."""
super().__init__()
self.agent = SemanticKernelTravelAgent()
self.notification_sender_auth = notification_sender_auth
super().__init__(SemanticKernelTravelAgent(), notification_sender_auth)

async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
"""A method to handle a task request.
Expand Down
47 changes: 47 additions & 0 deletions samples/python/common/server/base_task_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from abc import ABC, abstractmethod
from .task_manager import InMemoryTaskManager
from typing import AsyncIterable
from common.types import (
SendTaskRequest, SendTaskResponse, SendTaskStreamingRequest, SendTaskStreamingResponse, JSONRPCResponse, TaskSendParams, TaskStatus, Artifact, Task
)

class BaseAgentTaskManager(InMemoryTaskManager, ABC):
def __init__(self, agent, notification_sender_auth=None):
super().__init__()
self.agent = agent
self.notification_sender_auth = notification_sender_auth

@abstractmethod
async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
pass

@abstractmethod
async def on_send_task_subscribe(self, request: SendTaskStreamingRequest) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
pass

def _validate_request(self, request):
# Default output type validation logic, override if needed
return None

def _get_user_query(self, task_send_params: TaskSendParams) -> str:
# Default implementation for text-based query extraction
part = task_send_params.message.parts[0]
if hasattr(part, 'text'):
return part.text
raise ValueError("Only text parts are supported")

async def _update_store(self, task_id: str, status: TaskStatus, artifacts: list[Artifact]) -> Task:
# Default task store update implementation
async with self.lock:
try:
task = self.tasks[task_id]
except KeyError as exc:
raise ValueError(f"Task {task_id} not found") from exc
task.status = status
if status.message is not None:
self.task_messages[task_id].append(status.message)
if artifacts is not None:
if task.artifacts is None:
task.artifacts = []
task.artifacts.extend(artifacts)
return task