     KEY_MODEL_PARAMS, KEY_MODEL_CONTEXT_LENGTH,
     DEFAULT_LLM_TYPE, DEFAULT_MULTIMODAL_LLM_TYPE)
 from services.langchain_service import langchain_llm_generate
+from services.query.data_source import DataSourceFile
 from utils.errors import ContextWindowExceededException
 
 Logger = Logger.get_logger(__file__)
@@ -107,17 +108,16 @@ async def llm_generate(prompt: str, llm_type: str) -> str:
   except Exception as e:
     raise InternalServerError(str(e)) from e
 
-async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_types: List[str],
+async def llm_generate_multimodal(prompt: str, llm_type: str,
                                   user_file_bytes: bytes = None,
-                                  user_file_urls: List[str] = None) -> str:
+                                  user_files: List[DataSourceFile] = None) -> str:
   """
   Generate text with an LLM given a file and a prompt.
   Args:
     prompt: the text prompt to pass to the LLM
-    user_file_bytes: bytes of the file provided by the user
-    user_file_urls: list of URLs to include in context
-    user_file_types: list of mime times for files to include in context
     llm_type: the type of LLM to use (default to gemini)
+    user_file_bytes: bytes of the file provided by the user
+    user_files: list of DataSourceFile objects for file metadata
   Returns:
     the text response: str
   """
@@ -146,7 +146,7 @@ async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_types: List[str],
         f"Vertex model {llm_type} needs to be multimodal")
     response = await google_llm_predict(prompt, is_chat, is_multimodal,
                                         google_llm, None, user_file_bytes,
-                                        user_file_urls, user_file_types)
+                                        user_files)
   else:
     raise ResourceNotFoundException(f"Cannot find llm type '{llm_type}'")
 
@@ -160,8 +160,7 @@ async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_types: List[str],
 async def llm_chat(prompt: str, llm_type: str,
                    user_chat: Optional[UserChat] = None,
                    user_query: Optional[UserQuery] = None,
-                   chat_file_types: Optional[List[str]] = None,
-                   chat_file_urls: Optional[List[str]] = None,
+                   chat_files: Optional[List[DataSourceFile]] = None,
                    chat_file_bytes: Optional[bytes] = None) -> str:
   """
   Send a prompt to a chat model and return string response.
@@ -173,9 +172,9 @@ async def llm_chat(prompt: str, llm_type: str,
     llm_type: the type of LLM to use
     user_chat (optional): a user chat to use for context
     user_query (optional): a user query to use for context
-    chat_file_bytes (bytes): bytes of file to include in chat context
-    chat_file_urls (List[str]): urls of files to include in chat context
-    chat_file_types (List[str]): mime types of files to include in chat context
+    chat_files (optional) (List[DataSourceFile]): files to include in chat context
+    chat_file_bytes (optional) (bytes): bytes of file to include in chat context
+
   Returns:
     the text response: str
   """
@@ -185,22 +184,19 @@ async def llm_chat(prompt: str, llm_type: str,
               f" user_chat=[{user_chat}]"
               f" user_query=[{user_query}]"
               f" chat_file_bytes=[{chat_file_bytes_log}]"
-              f" chat_file_urls=[{chat_file_urls}]"
-              f" chat_file_type=[{chat_file_types}]")
+              f" chat_files=[{chat_files}]")
 
   if llm_type not in get_model_config().get_chat_llm_types():
     raise ResourceNotFoundException(f"Cannot find chat llm type '{llm_type}'")
 
   # validate chat file params
   is_multimodal = False
-  if chat_file_bytes is not None or chat_file_urls:
-    if chat_file_bytes is not None and chat_file_urls:
+  if chat_file_bytes is not None or chat_files:
+    if chat_file_bytes is not None and chat_files:
       raise InternalServerError(
-        "Must set only one of chat_file_bytes/chat_file_urls")
+        "Must set only one of chat_file_bytes/chat_files")
     if llm_type not in get_provider_models(PROVIDER_VERTEX):
       raise InternalServerError("Chat files only supported for Vertex")
-    if chat_file_types is None:
-      raise InternalServerError("Mime type must be passed for chat file")
     is_multimodal = True
 
   try:
@@ -244,7 +240,7 @@ async def llm_chat(prompt: str, llm_type: str,
       response = await google_llm_predict(prompt, is_chat, is_multimodal,
                                           google_llm, user_chat,
                                           chat_file_bytes,
-                                          chat_file_urls, chat_file_types)
+                                          chat_files)
     elif llm_type in get_provider_models(PROVIDER_LANGCHAIN):
       response = await langchain_llm_generate(prompt, llm_type, user_chat)
     return response
@@ -496,9 +492,8 @@ async def model_garden_predict(prompt: str,
 
 async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
                              google_llm: str, user_chat=None,
-                             user_file_bytes: Optional[bytes] = None,
-                             user_file_urls: Optional[List[str]] = None,
-                             user_file_types: Optional[List[str]] = None) -> str:
+                             user_file_bytes: bytes = None,
+                             user_files: List[DataSourceFile] = None) -> str:
   """
   Generate text with a Google multimodal LLM given a prompt.
   Args:
@@ -508,8 +503,7 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
     google_llm: name of the vertex llm model
     user_chat: chat history
     user_file_bytes: the bytes of the file provided by the user
-    user_file_urls: list of urls of files provided by the user
-    user_file_types: list of mime types of the files provided by the user
+    user_files: list of DataSourceFiles for files provided by the user
   Returns:
     the text response.
   """
@@ -519,8 +513,7 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
               f" prompt=[{prompt}], is_chat=[{is_chat}],"
               f" is_multimodal=[{is_multimodal}], google_llm=[{google_llm}],"
               f" user_file_bytes=[{user_file_bytes_log}],"
-              f" user_file_urls=[{user_file_urls}],"
-              f" user_file_type=[{user_file_types}].")
+              f" user_files=[{user_files}]")
 
   # TODO: Consider images in chat
   prompt_list = []
@@ -563,21 +556,20 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
     chat_model = GenerativeModel(google_llm)
     if is_multimodal:
       user_file_parts = []
-      if user_file_bytes is not None:
+      if user_file_bytes is not None and user_files is not None:
         # user_file_bytes refers to a single image and so we index into
-        # user_file_types (a list) to get a single mime type
+        # user_files (a list) to get a single mime type
         user_file_parts = [Part.from_data(user_file_bytes,
-                                          mime_type=user_file_types[0])]
-      elif user_file_urls is not None:
-        # user_file_urls and user_file_types are same-length lists
-        # referring to one or more images
+                                          mime_type=user_files[0].mime_type)]
+      elif user_files is not None:
+        # user_files is a list referring to one or more images
         user_file_parts = [
-          Part.from_uri(user_file_url, mime_type=user_file_type)
-          for user_file_url, user_file_type in zip(user_file_urls, user_file_types)
+          Part.from_uri(user_file.gcs_path, mime_type=user_file.mime_type)
+          for user_file in user_files
         ]
       else:
         raise RuntimeError(
-          "if is_multimodal one of user_file_bytes or user_file_urls must be set")
+          "if is_multimodal user_files must be set")
       context_list = [*user_file_parts, context_prompt]
       Logger.info(f"context list {context_list}")
     generation_config = GenerationConfig(**parameters)
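The Part-building branch above can be read in isolation as the sketch below. It uses a minimal stand-in dataclass for DataSourceFile (the real class in services.query.data_source may carry more fields) and assumes the vertexai Part.from_data/Part.from_uri helpers that this module imports elsewhere.

```python
# Standalone sketch of the branch above, with a stand-in for DataSourceFile.
from dataclasses import dataclass
from typing import List, Optional
from vertexai.generative_models import Part

@dataclass
class DataSourceFileStub:
  gcs_path: str   # e.g. "gs://bucket/object.png"
  mime_type: str  # e.g. "image/png"

def build_user_file_parts(
    user_file_bytes: Optional[bytes],
    user_files: Optional[List[DataSourceFileStub]]) -> List[Part]:
  """Mirror the diff: raw bytes use the first file's mime type; otherwise
  every file becomes a GCS-URI part."""
  if user_file_bytes is not None and user_files:
    return [Part.from_data(user_file_bytes, mime_type=user_files[0].mime_type)]
  if user_files:
    return [Part.from_uri(f.gcs_path, mime_type=f.mime_type) for f in user_files]
  raise RuntimeError("if is_multimodal user_files must be set")
```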