1
- """
2
- An agent that handles reimbursement requests. Pretty much a copy of the
1
+ """An agent that handles reimbursement requests. Pretty much a copy of the
3
2
reimbursement agent from this repo, just made the tools a bit more interesting.
4
3
"""
5
4
6
- from typing import Any , Optional
7
5
import json
8
- import random
9
6
import logging
7
+ import random
10
8
9
+ from typing import Any , Optional
10
+
11
+ from agents .restate .middleware import AgentInvokeResult
12
+ from common .types import TextPart
11
13
from google .adk .agents .llm_agent import LlmAgent
12
- from google .adk .tools .tool_context import ToolContext
13
14
from google .adk .artifacts import InMemoryArtifactService
14
15
from google .adk .memory .in_memory_memory_service import InMemoryMemoryService
15
16
from google .adk .runners import Runner
16
17
from google .adk .sessions import InMemorySessionService
18
+ from google .adk .tools .tool_context import ToolContext
17
19
from google .genai import types
18
- from common .types import TextPart
19
20
20
- from agents .restate .middleware import AgentInvokeResult
21
21
22
22
logger = logging .getLogger (__name__ )
23
23
@@ -31,8 +31,7 @@ def create_request_form(
31
31
amount : Optional [str ] = None ,
32
32
purpose : Optional [str ] = None ,
33
33
) -> dict [str , Any ]:
34
- """
35
- Create a request form for the employee to fill out.
34
+ """Create a request form for the employee to fill out.
36
35
37
36
Args:
38
37
date (str): The date of the request. Can be an empty string.
@@ -42,20 +41,20 @@ def create_request_form(
42
41
Returns:
43
42
dict[str, Any]: A dictionary containing the request form data.
44
43
"""
45
- logger .info (" Creating reimbursement request" )
46
- request_id = " request_id_" + str (random .randint (1000000 , 9999999 ))
44
+ logger .info (' Creating reimbursement request' )
45
+ request_id = ' request_id_' + str (random .randint (1000000 , 9999999 ))
47
46
request_ids .add (request_id )
48
47
reimbursement = {
49
- " request_id" : request_id ,
50
- " date" : " <transaction date>" if not date else date ,
51
- " amount" : " <transaction dollar amount>" if not amount else amount ,
52
- " purpose" : (
53
- " <business justification/purpose of the transaction>"
48
+ ' request_id' : request_id ,
49
+ ' date' : ' <transaction date>' if not date else date ,
50
+ ' amount' : ' <transaction dollar amount>' if not amount else amount ,
51
+ ' purpose' : (
52
+ ' <business justification/purpose of the transaction>'
54
53
if not purpose
55
54
else purpose
56
55
),
57
56
}
58
- logger .info (" Reimbursement request created: %s" , json .dumps (reimbursement ))
57
+ logger .info (' Reimbursement request created: %s' , json .dumps (reimbursement ))
59
58
60
59
return reimbursement
61
60
@@ -65,8 +64,7 @@ def return_form(
65
64
tool_context : ToolContext ,
66
65
instructions : Optional [str ] = None ,
67
66
) -> dict [str , Any ]:
68
- """
69
- Returns a structured json object indicating a form to complete.
67
+ """Returns a structured json object indicating a form to complete.
70
68
71
69
Args:
72
70
form_request (dict[str, Any]): The request form data.
@@ -76,64 +74,67 @@ def return_form(
76
74
Returns:
77
75
dict[str, Any]: A JSON dictionary for the form response.
78
76
"""
79
- logger .info (" Creating return form" )
77
+ logger .info (' Creating return form' )
80
78
if isinstance (form_request , str ):
81
79
form_request = json .loads (form_request )
82
80
83
81
form_dict = {
84
- " type" : " form" ,
85
- " form" : {
86
- " type" : " object" ,
87
- " properties" : {
88
- " date" : {
89
- " type" : " string" ,
90
- " format" : " date" ,
91
- " description" : " Date of expense" ,
92
- " title" : " Date" ,
82
+ ' type' : ' form' ,
83
+ ' form' : {
84
+ ' type' : ' object' ,
85
+ ' properties' : {
86
+ ' date' : {
87
+ ' type' : ' string' ,
88
+ ' format' : ' date' ,
89
+ ' description' : ' Date of expense' ,
90
+ ' title' : ' Date' ,
93
91
},
94
- " amount" : {
95
- " type" : " string" ,
96
- " format" : " number" ,
97
- " description" : " Amount of expense" ,
98
- " title" : " Amount" ,
92
+ ' amount' : {
93
+ ' type' : ' string' ,
94
+ ' format' : ' number' ,
95
+ ' description' : ' Amount of expense' ,
96
+ ' title' : ' Amount' ,
99
97
},
100
- " purpose" : {
101
- " type" : " string" ,
102
- " description" : " Purpose of expense" ,
103
- " title" : " Purpose" ,
98
+ ' purpose' : {
99
+ ' type' : ' string' ,
100
+ ' description' : ' Purpose of expense' ,
101
+ ' title' : ' Purpose' ,
104
102
},
105
- " request_id" : {
106
- " type" : " string" ,
107
- " description" : " Request id" ,
108
- " title" : " Request ID" ,
103
+ ' request_id' : {
104
+ ' type' : ' string' ,
105
+ ' description' : ' Request id' ,
106
+ ' title' : ' Request ID' ,
109
107
},
110
108
},
111
- " required" : list (form_request .keys ()),
109
+ ' required' : list (form_request .keys ()),
112
110
},
113
- " form_data" : form_request ,
114
- " instructions" : instructions ,
111
+ ' form_data' : form_request ,
112
+ ' instructions' : instructions ,
115
113
}
116
- logger .info (" Return form created: %s" , json .dumps (form_dict ))
114
+ logger .info (' Return form created: %s' , json .dumps (form_dict ))
117
115
return json .dumps (form_dict )
118
116
119
117
120
118
async def reimburse (request_id : str ) -> dict [str , Any ]:
121
119
"""Reimburse the amount of money to the employee for a given request_id."""
122
- logger .info (" Starting reimbursement: %s" , request_id )
120
+ logger .info (' Starting reimbursement: %s' , request_id )
123
121
if request_id not in request_ids :
124
- return {"request_id" : request_id , "status" : "Error: Invalid request_id." }
125
- logger .info ("Reimbursement approved: %s" , request_id )
126
- return {"request_id" : request_id , "status" : "approved" }
122
+ return {
123
+ 'request_id' : request_id ,
124
+ 'status' : 'Error: Invalid request_id.' ,
125
+ }
126
+ logger .info ('Reimbursement approved: %s' , request_id )
127
+ return {'request_id' : request_id , 'status' : 'approved' }
127
128
128
129
129
- class ReimbursementAgent () :
130
+ class ReimbursementAgent :
130
131
"""An agent that handles reimbursement requests."""
131
132
132
- SUPPORTED_CONTENT_TYPES = [" text" , " text/plain" ]
133
+ SUPPORTED_CONTENT_TYPES = [' text' , ' text/plain' ]
133
134
134
135
def __init__ (self ):
135
136
self ._agent = self ._build_agent ()
136
- self ._user_id = " remote_agent"
137
+ self ._user_id = ' remote_agent'
137
138
self ._runner = Runner (
138
139
app_name = self ._agent .name ,
139
140
agent = self ._agent ,
@@ -143,11 +144,15 @@ def __init__(self):
143
144
)
144
145
145
146
async def invoke (self , query , session_id ) -> AgentInvokeResult :
146
- logger .info (" Invoking LLM" )
147
+ logger .info (' Invoking LLM' )
147
148
session = self ._runner .session_service .get_session (
148
- app_name = self ._agent .name , user_id = self ._user_id , session_id = session_id
149
+ app_name = self ._agent .name ,
150
+ user_id = self ._user_id ,
151
+ session_id = session_id ,
152
+ )
153
+ content = types .Content (
154
+ role = 'user' , parts = [types .Part .from_text (text = query )]
149
155
)
150
- content = types .Content (role = "user" , parts = [types .Part .from_text (text = query )])
151
156
if session is None :
152
157
self ._runner .session_service .create_session (
153
158
app_name = self ._agent .name ,
@@ -162,19 +167,21 @@ async def invoke(self, query, session_id) -> AgentInvokeResult:
162
167
session_id = session_id ,
163
168
new_message = content ,
164
169
):
165
- events .append (event )
170
+ events .append (event )
166
171
167
- logger .info (" LLM response: %s" , events )
172
+ logger .info (' LLM response: %s' , events )
168
173
if not events or not events [- 1 ].content or not events [- 1 ].content .parts :
169
174
return AgentInvokeResult (
170
- parts = [TextPart (text = "" )],
175
+ parts = [TextPart (text = '' )],
171
176
require_user_input = False ,
172
177
is_task_complete = True ,
173
178
)
174
179
return AgentInvokeResult (
175
180
parts = [
176
181
TextPart (
177
- text = "\n " .join ([p .text for p in events [- 1 ].content .parts if p .text ])
182
+ text = '\n ' .join (
183
+ [p .text for p in events [- 1 ].content .parts if p .text ]
184
+ )
178
185
)
179
186
],
180
187
require_user_input = False ,
@@ -184,11 +191,11 @@ async def invoke(self, query, session_id) -> AgentInvokeResult:
184
191
def _build_agent (self ) -> LlmAgent :
185
192
"""Builds the LLM agent for the reimbursement agent."""
186
193
return LlmAgent (
187
- model = " gemini-2.0-flash-001" ,
188
- name = " reimbursement_agent" ,
194
+ model = ' gemini-2.0-flash-001' ,
195
+ name = ' reimbursement_agent' ,
189
196
description = (
190
- " This agent handles the reimbursement process for the employees"
191
- " given the amount and purpose of the reimbursement."
197
+ ' This agent handles the reimbursement process for the employees'
198
+ ' given the amount and purpose of the reimbursement.'
192
199
),
193
200
instruction = """
194
201
You are an agent who handle the reimbursement process for employees.
0 commit comments