diff --git a/samples/python/agents/crewai/task_manager.py b/samples/python/agents/crewai/task_manager.py index ddc1bcfc..ec9f0919 100644 --- a/samples/python/agents/crewai/task_manager.py +++ b/samples/python/agents/crewai/task_manager.py @@ -3,7 +3,7 @@ import logging from typing import AsyncIterable from agent import ImageGenerationAgent -from common.server.task_manager import InMemoryTaskManager +from common.server.base_task_manager import BaseAgentTaskManager from common.server import utils from common.types import ( Artifact, @@ -24,22 +24,15 @@ logger = logging.getLogger(__name__) -class AgentTaskManager(InMemoryTaskManager): +class AgentTaskManager(BaseAgentTaskManager): """Agent Task Manager, handles task routing and response packing.""" def __init__(self, agent: ImageGenerationAgent): - super().__init__() - self.agent = agent - - async def _stream_generator( - self, request: SendTaskRequest - ) -> AsyncIterable[SendTaskResponse]: - raise NotImplementedError("Not implemented") + super().__init__(agent) async def on_send_task( self, request: SendTaskRequest - ) -> SendTaskResponse | AsyncIterable[SendTaskResponse]: - ## only support text output at the moment + ) -> SendTaskResponse: if not utils.are_modalities_compatible( request.params.acceptedOutputModes, ImageGenerationAgent.SUPPORTED_CONTENT_TYPES, @@ -53,7 +46,6 @@ async def on_send_task( task_send_params: TaskSendParams = request.params await self.upsert_task(task_send_params) - return await self._invoke(request) async def on_send_task_subscribe( @@ -64,6 +56,8 @@ async def on_send_task_subscribe( return error await self.upsert_task(request.params) + # streaming is not implemented for crewai + raise NotImplementedError("Not implemented") async def _update_store( self, task_id: str, status: TaskStatus, artifacts: list[Artifact] diff --git a/samples/python/agents/google_adk/task_manager.py b/samples/python/agents/google_adk/task_manager.py index 6d45c301..5324595d 100644 --- a/samples/python/agents/google_adk/task_manager.py +++ b/samples/python/agents/google_adk/task_manager.py @@ -18,14 +18,13 @@ SendTaskStreamingRequest, SendTaskStreamingResponse, ) -from common.server.task_manager import InMemoryTaskManager +from common.server.base_task_manager import BaseAgentTaskManager from google.genai import types import common.server.utils as utils from typing import Union import logging logger = logging.getLogger(__name__) -# TODO: Move this class (or these classes) to a common directory class AgentWithTaskManager(ABC): @abstractmethod @@ -93,11 +92,26 @@ async def stream(self, query, session_id) -> AsyncIterable[Dict[str, Any]]: "updates": self.get_processing_message(), } -class AgentTaskManager(InMemoryTaskManager): +class AgentTaskManager(BaseAgentTaskManager): def __init__(self, agent: AgentWithTaskManager): - super().__init__() - self.agent = agent + super().__init__(agent) + + async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: + error = self._validate_request(request) + if error: + return error + await self.upsert_task(request.params) + return await self._invoke(request) + + async def on_send_task_subscribe( + self, request: SendTaskStreamingRequest + ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse: + error = self._validate_request(request) + if error: + return error + await self.upsert_task(request.params) + return self._stream_generator(request) async def _stream_generator( self, request: SendTaskStreamingRequest @@ -105,7 +119,7 @@ async def _stream_generator( task_send_params: TaskSendParams = request.params query = self._get_user_query(task_send_params) try: - async for item in self.agent.stream(query, task_send_params.sessionId): + async for item in self.stream(query, task_send_params.sessionId): is_task_complete = item["is_task_complete"] artifacts = None if not is_task_complete: @@ -163,6 +177,7 @@ async def _stream_generator( message="An error occurred while streaming the response" ), ) + def _validate_request( self, request: Union[SendTaskRequest, SendTaskStreamingRequest] ) -> None: @@ -176,20 +191,7 @@ def _validate_request( self.agent.SUPPORTED_CONTENT_TYPES, ) return utils.new_incompatible_types_error(request.id) - async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: - error = self._validate_request(request) - if error: - return error - await self.upsert_task(request.params) - return await self._invoke(request) - async def on_send_task_subscribe( - self, request: SendTaskStreamingRequest - ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse: - error = self._validate_request(request) - if error: - return error - await self.upsert_task(request.params) - return self._stream_generator(request) + async def _update_store( self, task_id: str, status: TaskStatus, artifacts: list[Artifact] ) -> Task: @@ -207,11 +209,12 @@ async def _update_store( task.artifacts = [] task.artifacts.extend(artifacts) return task + async def _invoke(self, request: SendTaskRequest) -> SendTaskResponse: task_send_params: TaskSendParams = request.params query = self._get_user_query(task_send_params) try: - result = self.agent.invoke(query, task_send_params.sessionId) + result = self.invoke(query, task_send_params.sessionId) except Exception as e: logger.error(f"Error invoking agent: {e}") raise ValueError(f"Error invoking agent: {e}") @@ -225,6 +228,7 @@ async def _invoke(self, request: SendTaskRequest) -> SendTaskResponse: [Artifact(parts=parts)], ) return SendTaskResponse(id=request.id, result=task) + def _get_user_query(self, task_send_params: TaskSendParams) -> str: part = task_send_params.message.parts[0] if not isinstance(part, TextPart): diff --git a/samples/python/agents/marvin/task_manager.py b/samples/python/agents/marvin/task_manager.py index 14dc8c2f..b2fc5962 100644 --- a/samples/python/agents/marvin/task_manager.py +++ b/samples/python/agents/marvin/task_manager.py @@ -6,7 +6,7 @@ import common.server.utils as utils from agents.marvin.agent import ExtractorAgent -from common.server.task_manager import InMemoryTaskManager +from common.server.base_task_manager import BaseAgentTaskManager from common.types import ( Artifact, DataPart, @@ -33,15 +33,13 @@ logger = logging.getLogger(__name__) -class AgentTaskManager(InMemoryTaskManager): +class AgentTaskManager(BaseAgentTaskManager): def __init__( self, agent: ExtractorAgent, notification_sender_auth: PushNotificationSenderAuth, ): - super().__init__() - self.agent = agent - self.notification_sender_auth = notification_sender_auth + super().__init__(agent, notification_sender_auth) def _parse_agent_outcome( self, agent_outcome: dict[str, Any] diff --git a/samples/python/agents/semantickernel/task_manager.py b/samples/python/agents/semantickernel/task_manager.py index d5053e1f..324ef91b 100644 --- a/samples/python/agents/semantickernel/task_manager.py +++ b/samples/python/agents/semantickernel/task_manager.py @@ -2,7 +2,7 @@ import logging from typing import AsyncIterable -from common.server.task_manager import InMemoryTaskManager +from common.server.base_task_manager import BaseAgentTaskManager from common.types import ( Artifact, InternalError, @@ -25,14 +25,12 @@ logger = logging.getLogger(__name__) -class TaskManager(InMemoryTaskManager): +class TaskManager(BaseAgentTaskManager): """A TaskManager used for the Semantic Kernel Agent sample.""" def __init__(self, notification_sender_auth: PushNotificationSenderAuth): """Initialize the TaskManager with a notification sender.""" - super().__init__() - self.agent = SemanticKernelTravelAgent() - self.notification_sender_auth = notification_sender_auth + super().__init__(SemanticKernelTravelAgent(), notification_sender_auth) async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: """A method to handle a task request. diff --git a/samples/python/common/server/base_task_manager.py b/samples/python/common/server/base_task_manager.py new file mode 100644 index 00000000..f225317b --- /dev/null +++ b/samples/python/common/server/base_task_manager.py @@ -0,0 +1,47 @@ +from abc import ABC, abstractmethod +from .task_manager import InMemoryTaskManager +from typing import AsyncIterable +from common.types import ( + SendTaskRequest, SendTaskResponse, SendTaskStreamingRequest, SendTaskStreamingResponse, JSONRPCResponse, TaskSendParams, TaskStatus, Artifact, Task +) + +class BaseAgentTaskManager(InMemoryTaskManager, ABC): + def __init__(self, agent, notification_sender_auth=None): + super().__init__() + self.agent = agent + self.notification_sender_auth = notification_sender_auth + + @abstractmethod + async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse: + pass + + @abstractmethod + async def on_send_task_subscribe(self, request: SendTaskStreamingRequest) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse: + pass + + def _validate_request(self, request): + # Default output type validation logic, override if needed + return None + + def _get_user_query(self, task_send_params: TaskSendParams) -> str: + # Default implementation for text-based query extraction + part = task_send_params.message.parts[0] + if hasattr(part, 'text'): + return part.text + raise ValueError("Only text parts are supported") + + async def _update_store(self, task_id: str, status: TaskStatus, artifacts: list[Artifact]) -> Task: + # Default task store update implementation + async with self.lock: + try: + task = self.tasks[task_id] + except KeyError as exc: + raise ValueError(f"Task {task_id} not found") from exc + task.status = status + if status.message is not None: + self.task_messages[task_id].append(status.message) + if artifacts is not None: + if task.artifacts is None: + task.artifacts = [] + task.artifacts.extend(artifacts) + return task \ No newline at end of file