@@ -107,7 +107,7 @@ async def llm_generate(prompt: str, llm_type: str) -> str:
   except Exception as e:
     raise InternalServerError(str(e)) from e

-async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_type: str,
+async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_types: List[str],
                                    user_file_bytes: bytes = None,
                                    user_file_urls: List[str] = None) -> str:
   """
@@ -116,6 +116,7 @@ async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_type: st
     prompt: the text prompt to pass to the LLM
     user_file_bytes: bytes of the file provided by the user
     user_file_urls: list of URLs to include in context
+    user_file_types: list of mime types for files to include in context
     llm_type: the type of LLM to use (default to gemini)
   Returns:
     the text response: str
@@ -145,7 +146,7 @@ async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_type: st
           f"Vertex model {llm_type} needs to be multimodal")
     response = await google_llm_predict(prompt, is_chat, is_multimodal,
                                         google_llm, None, user_file_bytes,
-                                        user_file_urls, user_file_type)
+                                        user_file_urls, user_file_types)
   else:
     raise ResourceNotFoundException(f"Cannot find llm type '{llm_type}'")

@@ -159,9 +160,9 @@ async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_type: st
 async def llm_chat(prompt: str, llm_type: str,
                    user_chat: Optional[UserChat] = None,
                    user_query: Optional[UserQuery] = None,
-                   chat_file_type: str = None,
-                   chat_file_urls: List[str] = None,
-                   chat_file_bytes: bytes = None) -> str:
+                   chat_file_types: Optional[List[str]] = None,
+                   chat_file_urls: Optional[List[str]] = None,
+                   chat_file_bytes: Optional[bytes] = None) -> str:
   """
   Send a prompt to a chat model and return string response.
   Supports including a file in the chat context, either by URL or
@@ -174,7 +175,7 @@ async def llm_chat(prompt: str, llm_type: str,
     user_query (optional): a user query to use for context
     chat_file_bytes (bytes): bytes of file to include in chat context
     chat_file_urls (List[str]): urls of files to include in chat context
-    chat_file_type (str): mime type of file to include in chat context
+    chat_file_types (List[str]): mime types of files to include in chat context
   Returns:
     the text response: str
   """
@@ -185,7 +186,7 @@ async def llm_chat(prompt: str, llm_type: str,
               f" user_query=[{user_query}]"
               f" chat_file_bytes=[{chat_file_bytes_log}]"
               f" chat_file_urls=[{chat_file_urls}]"
-              f" chat_file_type=[{chat_file_type}]")
+              f" chat_file_type=[{chat_file_types}]")

   if llm_type not in get_model_config().get_chat_llm_types():
     raise ResourceNotFoundException(f"Cannot find chat llm type '{llm_type}'")
@@ -198,7 +199,7 @@ async def llm_chat(prompt: str, llm_type: str,
           "Must set only one of chat_file_bytes/chat_file_urls")
     if llm_type not in get_provider_models(PROVIDER_VERTEX):
       raise InternalServerError("Chat files only supported for Vertex")
-    if chat_file_type is None:
+    if chat_file_types is None:
       raise InternalServerError("Mime type must be passed for chat file")
     is_multimodal = True

@@ -209,6 +210,8 @@ async def llm_chat(prompt: str, llm_type: str,
   if user_chat is not None or user_query is not None:
     context_prompt = get_context_prompt(
         user_chat=user_chat, user_query=user_query)
+    # context_prompt includes only text (no images/video) from
+    # user_chat.history and user_query.history
     prompt = context_prompt + "\n" + prompt

   # check whether the context length exceeds the limit for the model
@@ -241,7 +244,7 @@ async def llm_chat(prompt: str, llm_type: str,
       response = await google_llm_predict(prompt, is_chat, is_multimodal,
                                           google_llm, user_chat,
                                           chat_file_bytes,
-                                          chat_file_urls, chat_file_type)
+                                          chat_file_urls, chat_file_types)
     elif llm_type in get_provider_models(PROVIDER_LANGCHAIN):
       response = await langchain_llm_generate(prompt, llm_type, user_chat)
   return response
@@ -271,6 +274,7 @@ def get_context_prompt(user_chat=None,
         prompt_list.append(f"Human input: {content}")
       elif UserChat.is_ai(entry):
         prompt_list.append(f"AI response: {content}")
+    # prompt_list includes only text from user_chat.history

   if user_query is not None:
     history = user_query.history
@@ -280,6 +284,7 @@ def get_context_prompt(user_chat=None,
280
284
prompt_list .append (f"Human input: { content } " )
281
285
elif UserQuery .is_ai (entry ):
282
286
prompt_list .append (f"AI response: { content } " )
287
+ # prompt_list includes only text from user_query.history
283
288
284
289
context_prompt = "\n \n " .join (prompt_list )
285
290
@@ -294,6 +299,8 @@ def check_context_length(prompt, llm_type):
   """
   # check if prompt exceeds context window length for model
   # assume a constant relationship between tokens and chars
+  # TODO: Recalculate max_context_length for text prompt,
+  # subtracting out tokens used by non-text context (image, video, etc)
   token_length = len(prompt) / CHARS_PER_TOKEN
   max_context_length = get_model_config_value(llm_type,
                                               KEY_MODEL_CONTEXT_LENGTH,
@@ -489,9 +496,9 @@ async def model_garden_predict(prompt: str,

 async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
                              google_llm: str, user_chat=None,
-                             user_file_bytes: bytes = None,
-                             user_file_urls: List[str] = None,
-                             user_file_type: str = None) -> str:
+                             user_file_bytes: Optional[bytes] = None,
+                             user_file_urls: Optional[List[str]] = None,
+                             user_file_types: Optional[List[str]] = None) -> str:
   """
   Generate text with a Google multimodal LLM given a prompt.
   Args:
@@ -502,7 +509,7 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
     user_chat: chat history
     user_file_bytes: the bytes of the file provided by the user
     user_file_urls: list of urls of files provided by the user
-    user_file_type: mime type of the file provided by the user
+    user_file_types: list of mime types of the files provided by the user
   Returns:
     the text response.
   """
@@ -513,7 +520,7 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
               f" is_multimodal=[{is_multimodal}], google_llm=[{google_llm}],"
               f" user_file_bytes=[{user_file_bytes_log}],"
               f" user_file_urls=[{user_file_urls}],"
-              f" user_file_type=[{user_file_type}].")
+              f" user_file_type=[{user_file_types}].")

   # TODO: Consider images in chat
   prompt_list = []
@@ -525,6 +532,8 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
         prompt_list.append(f"Human input: {content}")
       elif UserChat.is_ai(entry):
         prompt_list.append(f"AI response: {content}")
+    # prompt_list includes only text (no images/video)
+    # from user_chat.history
   prompt_list.append(prompt)
   context_prompt = "\n\n".join(prompt_list)

@@ -555,12 +564,16 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
   if is_multimodal:
     user_file_parts = []
     if user_file_bytes is not None:
+      # user_file_bytes refers to a single image and so we index into
+      # user_file_types (a list) to get a single mime type
       user_file_parts = [Part.from_data(user_file_bytes,
-                                        mime_type=user_file_type)]
+                                        mime_type=user_file_types[0])]
     elif user_file_urls is not None:
+      # user_file_urls and user_file_types are same-length lists
+      # referring to one or more images
       user_file_parts = [
         Part.from_uri(user_file_url, mime_type=user_file_type)
-        for user_file_url in user_file_urls
+        for user_file_url, user_file_type in zip(user_file_urls, user_file_types)
       ]
     else:
       raise RuntimeError(
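For illustration, here is a minimal usage sketch of the updated list-based interface. It is not part of the diff above: the import path, model key, and gs:// URLs are assumptions, and only the llm_generate_multimodal signature itself comes from this change. The key point is that user_file_types carries one mime type per entry of user_file_urls, in matching order, because google_llm_predict zips the two lists when building the Part objects.

# Hypothetical usage sketch; module path, model key, and URLs are assumed,
# not taken from this PR.
import asyncio

from services.llm_generate import llm_generate_multimodal  # assumed import path

async def main():
  # One mime type per URL, in matching order, since google_llm_predict
  # zips user_file_urls with user_file_types to build the Part list.
  response = await llm_generate_multimodal(
      prompt="Describe the differences between the two attached images.",
      llm_type="VertexAI-Gemini-Pro-Vision",  # assumed multimodal model key
      user_file_types=["image/png", "image/jpeg"],
      user_file_urls=["gs://my-bucket/chart.png", "gs://my-bucket/photo.jpg"])
  print(response)

if __name__ == "__main__":
  asyncio.run(main())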