Skip to content

Commit b8fb766

Browse files
committed
Formatting
1 parent ad3fd7c commit b8fb766

File tree

3 files changed

+251
-193
lines changed

3 files changed

+251
-193
lines changed
+36-21
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,75 @@
11
"""A an example of serving a resilient agent using restate.dev"""
22

33
import os
4-
from fastapi import FastAPI
4+
55
import restate
6+
67
from agent import ReimbursementAgent
7-
from middleware import AgentMiddleware
8+
from common.types import (
9+
AgentCapabilities,
10+
AgentCard,
11+
AgentSkill,
12+
MissingAPIKeyError,
13+
)
814
from dotenv import load_dotenv
15+
from fastapi import FastAPI
16+
from middleware import AgentMiddleware
917

10-
from common.types import AgentCapabilities, AgentCard, AgentSkill, MissingAPIKeyError
1118

1219
load_dotenv()
1320

14-
RESTATE_HOST = os.getenv("RESTATE_HOST", "http://localhost:8080")
21+
RESTATE_HOST = os.getenv('RESTATE_HOST', 'http://localhost:8080')
1522

1623
AGENT_CARD = AgentCard(
17-
name="ReimbursementAgent",
18-
description="This agent handles the reimbursement process for the employees given the amount and purpose of the reimbursement.",
24+
name='ReimbursementAgent',
25+
description='This agent handles the reimbursement process for the employees given the amount and purpose of the reimbursement.',
1926
url=RESTATE_HOST,
20-
version="1.0.0",
27+
version='1.0.0',
2128
defaultInputModes=ReimbursementAgent.SUPPORTED_CONTENT_TYPES,
2229
defaultOutputModes=ReimbursementAgent.SUPPORTED_CONTENT_TYPES,
2330
capabilities=AgentCapabilities(streaming=False),
2431
skills=[
25-
AgentSkill(id="process_reimbursement",
26-
name="Process Reimbursement Tool",
27-
description="Helps with the reimbursement process for users given the amount and purpose of the reimbursement.",
28-
tags=["reimbursement"],
29-
examples=["Can you reimburse me $20 for my lunch with the clients?"])],
32+
AgentSkill(
33+
id='process_reimbursement',
34+
name='Process Reimbursement Tool',
35+
description='Helps with the reimbursement process for users given the amount and purpose of the reimbursement.',
36+
tags=['reimbursement'],
37+
examples=[
38+
'Can you reimburse me $20 for my lunch with the clients?'
39+
],
40+
)
41+
],
3042
)
3143

3244
REIMBURSEMENT_AGENT = AgentMiddleware(AGENT_CARD, ReimbursementAgent())
3345

3446
app = FastAPI()
3547

36-
@app.get("/.well-known/agent.json")
48+
49+
@app.get('/.well-known/agent.json')
3750
async def agent_json():
38-
"""serve the agent card"""
51+
"""Serve the agent card"""
3952
return REIMBURSEMENT_AGENT.agent_card_json
4053

41-
app.mount("/restate/v1", restate.app(REIMBURSEMENT_AGENT))
54+
55+
app.mount('/restate/v1', restate.app(REIMBURSEMENT_AGENT))
56+
4257

4358
def main():
4459
"""Serve the agent at a specified port using hypercorn."""
4560
import asyncio
61+
4662
import hypercorn
4763
import hypercorn.asyncio
4864

4965
if not os.getenv('GOOGLE_API_KEY'):
50-
raise MissingAPIKeyError(
51-
'GOOGLE_API_KEY environment variable not set.'
52-
)
66+
raise MissingAPIKeyError('GOOGLE_API_KEY environment variable not set.')
5367

54-
port = os.getenv("AGENT_PORT", "9080")
68+
port = os.getenv('AGENT_PORT', '9080')
5569
conf = hypercorn.Config()
56-
conf.bind = [f"0.0.0.0:{port}"]
70+
conf.bind = [f'0.0.0.0:{port}']
5771
asyncio.run(hypercorn.asyncio.serve(app, conf))
5872

59-
if __name__ == "__main__":
73+
74+
if __name__ == '__main__':
6075
main()

samples/python/agents/restate/agent.py

+71-64
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
1-
"""
2-
An agent that handles reimbursement requests. Pretty much a copy of the
1+
"""An agent that handles reimbursement requests. Pretty much a copy of the
32
reimbursement agent from this repo, just made the tools a bit more interesting.
43
"""
54

6-
from typing import Any, Optional
75
import json
8-
import random
96
import logging
7+
import random
108

9+
from typing import Any, Optional
10+
11+
from agents.restate.middleware import AgentInvokeResult
12+
from common.types import TextPart
1113
from google.adk.agents.llm_agent import LlmAgent
12-
from google.adk.tools.tool_context import ToolContext
1314
from google.adk.artifacts import InMemoryArtifactService
1415
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
1516
from google.adk.runners import Runner
1617
from google.adk.sessions import InMemorySessionService
18+
from google.adk.tools.tool_context import ToolContext
1719
from google.genai import types
18-
from common.types import TextPart
1920

20-
from agents.restate.middleware import AgentInvokeResult
2121

2222
logger = logging.getLogger(__name__)
2323

@@ -31,8 +31,7 @@ def create_request_form(
3131
amount: Optional[str] = None,
3232
purpose: Optional[str] = None,
3333
) -> dict[str, Any]:
34-
"""
35-
Create a request form for the employee to fill out.
34+
"""Create a request form for the employee to fill out.
3635
3736
Args:
3837
date (str): The date of the request. Can be an empty string.
@@ -42,20 +41,20 @@ def create_request_form(
4241
Returns:
4342
dict[str, Any]: A dictionary containing the request form data.
4443
"""
45-
logger.info("Creating reimbursement request")
46-
request_id = "request_id_" + str(random.randint(1000000, 9999999))
44+
logger.info('Creating reimbursement request')
45+
request_id = 'request_id_' + str(random.randint(1000000, 9999999))
4746
request_ids.add(request_id)
4847
reimbursement = {
49-
"request_id": request_id,
50-
"date": "<transaction date>" if not date else date,
51-
"amount": "<transaction dollar amount>" if not amount else amount,
52-
"purpose": (
53-
"<business justification/purpose of the transaction>"
48+
'request_id': request_id,
49+
'date': '<transaction date>' if not date else date,
50+
'amount': '<transaction dollar amount>' if not amount else amount,
51+
'purpose': (
52+
'<business justification/purpose of the transaction>'
5453
if not purpose
5554
else purpose
5655
),
5756
}
58-
logger.info("Reimbursement request created: %s", json.dumps(reimbursement))
57+
logger.info('Reimbursement request created: %s', json.dumps(reimbursement))
5958

6059
return reimbursement
6160

@@ -65,8 +64,7 @@ def return_form(
6564
tool_context: ToolContext,
6665
instructions: Optional[str] = None,
6766
) -> dict[str, Any]:
68-
"""
69-
Returns a structured json object indicating a form to complete.
67+
"""Returns a structured json object indicating a form to complete.
7068
7169
Args:
7270
form_request (dict[str, Any]): The request form data.
@@ -76,64 +74,67 @@ def return_form(
7674
Returns:
7775
dict[str, Any]: A JSON dictionary for the form response.
7876
"""
79-
logger.info("Creating return form")
77+
logger.info('Creating return form')
8078
if isinstance(form_request, str):
8179
form_request = json.loads(form_request)
8280

8381
form_dict = {
84-
"type": "form",
85-
"form": {
86-
"type": "object",
87-
"properties": {
88-
"date": {
89-
"type": "string",
90-
"format": "date",
91-
"description": "Date of expense",
92-
"title": "Date",
82+
'type': 'form',
83+
'form': {
84+
'type': 'object',
85+
'properties': {
86+
'date': {
87+
'type': 'string',
88+
'format': 'date',
89+
'description': 'Date of expense',
90+
'title': 'Date',
9391
},
94-
"amount": {
95-
"type": "string",
96-
"format": "number",
97-
"description": "Amount of expense",
98-
"title": "Amount",
92+
'amount': {
93+
'type': 'string',
94+
'format': 'number',
95+
'description': 'Amount of expense',
96+
'title': 'Amount',
9997
},
100-
"purpose": {
101-
"type": "string",
102-
"description": "Purpose of expense",
103-
"title": "Purpose",
98+
'purpose': {
99+
'type': 'string',
100+
'description': 'Purpose of expense',
101+
'title': 'Purpose',
104102
},
105-
"request_id": {
106-
"type": "string",
107-
"description": "Request id",
108-
"title": "Request ID",
103+
'request_id': {
104+
'type': 'string',
105+
'description': 'Request id',
106+
'title': 'Request ID',
109107
},
110108
},
111-
"required": list(form_request.keys()),
109+
'required': list(form_request.keys()),
112110
},
113-
"form_data": form_request,
114-
"instructions": instructions,
111+
'form_data': form_request,
112+
'instructions': instructions,
115113
}
116-
logger.info("Return form created: %s", json.dumps(form_dict))
114+
logger.info('Return form created: %s', json.dumps(form_dict))
117115
return json.dumps(form_dict)
118116

119117

120118
async def reimburse(request_id: str) -> dict[str, Any]:
121119
"""Reimburse the amount of money to the employee for a given request_id."""
122-
logger.info("Starting reimbursement: %s", request_id)
120+
logger.info('Starting reimbursement: %s', request_id)
123121
if request_id not in request_ids:
124-
return {"request_id": request_id, "status": "Error: Invalid request_id."}
125-
logger.info("Reimbursement approved: %s", request_id)
126-
return {"request_id": request_id, "status": "approved"}
122+
return {
123+
'request_id': request_id,
124+
'status': 'Error: Invalid request_id.',
125+
}
126+
logger.info('Reimbursement approved: %s', request_id)
127+
return {'request_id': request_id, 'status': 'approved'}
127128

128129

129-
class ReimbursementAgent():
130+
class ReimbursementAgent:
130131
"""An agent that handles reimbursement requests."""
131132

132-
SUPPORTED_CONTENT_TYPES = ["text", "text/plain"]
133+
SUPPORTED_CONTENT_TYPES = ['text', 'text/plain']
133134

134135
def __init__(self):
135136
self._agent = self._build_agent()
136-
self._user_id = "remote_agent"
137+
self._user_id = 'remote_agent'
137138
self._runner = Runner(
138139
app_name=self._agent.name,
139140
agent=self._agent,
@@ -143,11 +144,15 @@ def __init__(self):
143144
)
144145

145146
async def invoke(self, query, session_id) -> AgentInvokeResult:
146-
logger.info("Invoking LLM")
147+
logger.info('Invoking LLM')
147148
session = self._runner.session_service.get_session(
148-
app_name=self._agent.name, user_id=self._user_id, session_id=session_id
149+
app_name=self._agent.name,
150+
user_id=self._user_id,
151+
session_id=session_id,
152+
)
153+
content = types.Content(
154+
role='user', parts=[types.Part.from_text(text=query)]
149155
)
150-
content = types.Content(role="user", parts=[types.Part.from_text(text=query)])
151156
if session is None:
152157
self._runner.session_service.create_session(
153158
app_name=self._agent.name,
@@ -162,19 +167,21 @@ async def invoke(self, query, session_id) -> AgentInvokeResult:
162167
session_id=session_id,
163168
new_message=content,
164169
):
165-
events.append(event)
170+
events.append(event)
166171

167-
logger.info("LLM response: %s", events)
172+
logger.info('LLM response: %s', events)
168173
if not events or not events[-1].content or not events[-1].content.parts:
169174
return AgentInvokeResult(
170-
parts=[TextPart(text="")],
175+
parts=[TextPart(text='')],
171176
require_user_input=False,
172177
is_task_complete=True,
173178
)
174179
return AgentInvokeResult(
175180
parts=[
176181
TextPart(
177-
text="\n".join([p.text for p in events[-1].content.parts if p.text])
182+
text='\n'.join(
183+
[p.text for p in events[-1].content.parts if p.text]
184+
)
178185
)
179186
],
180187
require_user_input=False,
@@ -184,11 +191,11 @@ async def invoke(self, query, session_id) -> AgentInvokeResult:
184191
def _build_agent(self) -> LlmAgent:
185192
"""Builds the LLM agent for the reimbursement agent."""
186193
return LlmAgent(
187-
model="gemini-2.0-flash-001",
188-
name="reimbursement_agent",
194+
model='gemini-2.0-flash-001',
195+
name='reimbursement_agent',
189196
description=(
190-
"This agent handles the reimbursement process for the employees"
191-
" given the amount and purpose of the reimbursement."
197+
'This agent handles the reimbursement process for the employees'
198+
' given the amount and purpose of the reimbursement.'
192199
),
193200
instruction="""
194201
You are an agent who handle the reimbursement process for employees.

0 commit comments

Comments
 (0)