
Commit 80e3804

Merge pull request #54 from GPS-Solutions/resolve_merge_conflicts_from_main_to_dev
Merge aica:main to aica:dev
2 parents 02bd05f + 1ff84b5

File tree

8 files changed: +259 -85 lines changed

components/llm_service/src/routes/chat.py (+7 -10)
@@ -273,17 +273,15 @@ async def create_user_chat(
   if prompt is None or prompt == "":
     return BadRequest("Missing or invalid payload parameters")

-  # process chat file: upload to GCS and determine mime type
-  chat_file_type = None
+  # process chat file(s): upload to GCS and determine mime type
   chat_file_bytes = None
-  chat_file_urls = None
+  chat_files = None
   if chat_file is not None or chat_file_url is not None:
-    chat_file_urls, chat_file_type = \
-      await process_chat_file(chat_file, chat_file_url)
+    chat_files = await process_chat_file(chat_file, chat_file_url)

   # only read chat file bytes if for some reason we can't
-  # upload the file to GCS
-  if not chat_file_urls and chat_file is not None:
+  # upload the file(s) to GCS
+  if not chat_files and chat_file is not None:
     await chat_file.seek(0)
     chat_file_bytes = await chat_file.read()

@@ -293,9 +291,8 @@ async def create_user_chat(
   # generate text from prompt
   response = await llm_chat(prompt,
                             llm_type,
-                            chat_file_types=[chat_file_type],
-                            chat_file_bytes=chat_file_bytes,
-                            chat_file_urls=chat_file_urls)
+                            chat_files=chat_files,
+                            chat_file_bytes=chat_file_bytes)

   # create new chat for user
   user_chat = UserChat(user_id=user.user_id, llm_type=llm_type,
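With this change the chat route hands llm_chat a list of DataSourceFile objects instead of parallel chat_file_urls/chat_file_types lists. A minimal usage sketch of the new call shape, not code from the repo: the bucket path, mime type, and prompt are illustrative, and the import paths follow the conventions visible elsewhere in this diff.

from services.llm_generate import llm_chat
from services.query.data_source import DataSourceFile

async def ask_about_attachment(llm_type: str) -> str:
  # chat_files and chat_file_bytes are mutually exclusive per llm_chat's checks
  chat_files = [
      DataSourceFile(gcs_path="gs://example-bucket/diagram.png",
                     mime_type="image/png"),
  ]
  return await llm_chat("Describe this diagram", llm_type,
                        chat_files=chat_files)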

components/llm_service/src/services/llm_generate.py (+27 -35)
@@ -38,6 +38,7 @@
   KEY_MODEL_PARAMS, KEY_MODEL_CONTEXT_LENGTH,
   DEFAULT_LLM_TYPE, DEFAULT_MULTIMODAL_LLM_TYPE)
 from services.langchain_service import langchain_llm_generate
+from services.query.data_source import DataSourceFile
 from utils.errors import ContextWindowExceededException

 Logger = Logger.get_logger(__file__)
@@ -107,17 +108,16 @@ async def llm_generate(prompt: str, llm_type: str) -> str:
   except Exception as e:
     raise InternalServerError(str(e)) from e

-async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_types: List[str],
+async def llm_generate_multimodal(prompt: str, llm_type: str,
                                   user_file_bytes: bytes = None,
-                                  user_file_urls: List[str] = None) -> str:
+                                  user_files: List[DataSourceFile] = None) -> str:
   """
   Generate text with an LLM given a file and a prompt.
   Args:
     prompt: the text prompt to pass to the LLM
-    user_file_bytes: bytes of the file provided by the user
-    user_file_urls: list of URLs to include in context
-    user_file_types: list of mime times for files to include in context
     llm_type: the type of LLM to use (default to gemini)
+    user_file_bytes: bytes of the file provided by the user
+    user_files: list of DataSourceFile objects for file meta data
   Returns:
     the text response: str
   """
@@ -146,7 +146,7 @@ async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_types: L
         f"Vertex model {llm_type} needs to be multimodal")
     response = await google_llm_predict(prompt, is_chat, is_multimodal,
                                         google_llm, None, user_file_bytes,
-                                        user_file_urls, user_file_types)
+                                        user_files)
   else:
     raise ResourceNotFoundException(f"Cannot find llm type '{llm_type}'")

@@ -160,8 +160,7 @@ async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_types: L
 async def llm_chat(prompt: str, llm_type: str,
                    user_chat: Optional[UserChat] = None,
                    user_query: Optional[UserQuery] = None,
-                   chat_file_types: Optional[List[str]] = None,
-                   chat_file_urls: Optional[List[str]] = None,
+                   chat_files: Optional[List[DataSourceFile]] = None,
                    chat_file_bytes: Optional[bytes] = None) -> str:
   """
   Send a prompt to a chat model and return string response.
@@ -173,9 +172,9 @@ async def llm_chat(prompt: str, llm_type: str,
     llm_type: the type of LLM to use
     user_chat (optional): a user chat to use for context
     user_query (optional): a user query to use for context
-    chat_file_bytes (bytes): bytes of file to include in chat context
-    chat_file_urls (List[str]): urls of files to include in chat context
-    chat_file_types (List[str]): mime types of files to include in chat context
+    chat_files (optional) (List[DataSourceFile]): files to include in chat context
+    chat_file_bytes (optional) (bytes): bytes of file to include in chat context
+
   Returns:
     the text response: str
   """
@@ -185,22 +184,19 @@ async def llm_chat(prompt: str, llm_type: str,
               f" user_chat=[{user_chat}]"
               f" user_query=[{user_query}]"
               f" chat_file_bytes=[{chat_file_bytes_log}]"
-              f" chat_file_urls=[{chat_file_urls}]"
-              f" chat_file_type=[{chat_file_types}]")
+              f" chat_files=[{chat_files}]")

   if llm_type not in get_model_config().get_chat_llm_types():
     raise ResourceNotFoundException(f"Cannot find chat llm type '{llm_type}'")

   # validate chat file params
   is_multimodal = False
-  if chat_file_bytes is not None or chat_file_urls:
-    if chat_file_bytes is not None and chat_file_urls:
+  if chat_file_bytes is not None or chat_files:
+    if chat_file_bytes is not None and chat_files:
       raise InternalServerError(
-        "Must set only one of chat_file_bytes/chat_file_urls")
+        "Must set only one of chat_file_bytes/chat_files")
     if llm_type not in get_provider_models(PROVIDER_VERTEX):
       raise InternalServerError("Chat files only supported for Vertex")
-    if chat_file_types is None:
-      raise InternalServerError("Mime type must be passed for chat file")
     is_multimodal = True

   try:
@@ -244,7 +240,7 @@ async def llm_chat(prompt: str, llm_type: str,
       response = await google_llm_predict(prompt, is_chat, is_multimodal,
                                           google_llm, user_chat,
                                           chat_file_bytes,
-                                          chat_file_urls, chat_file_types)
+                                          chat_files)
     elif llm_type in get_provider_models(PROVIDER_LANGCHAIN):
       response = await langchain_llm_generate(prompt, llm_type, user_chat)
     return response
@@ -496,9 +492,8 @@ async def model_garden_predict(prompt: str,

 async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
                              google_llm: str, user_chat=None,
-                             user_file_bytes: Optional[bytes]=None,
-                             user_file_urls: Optional[List[str]]=None,
-                             user_file_types: Optional[List[str]]=None) -> str:
+                             user_file_bytes: bytes=None,
+                             user_files: List[DataSourceFile]=None) -> str:
   """
   Generate text with a Google multimodal LLM given a prompt.
   Args:
@@ -508,8 +503,7 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
     google_llm: name of the vertex llm model
     user_chat: chat history
     user_file_bytes: the bytes of the file provided by the user
-    user_file_urls: list of urls of files provided by the user
-    user_file_types: list of mime types of the files provided by the user
+    user_files: list of DataSourceFiles for files provided by the user
   Returns:
     the text response.
   """
@@ -519,8 +513,7 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
               f" prompt=[{prompt}], is_chat=[{is_chat}],"
               f" is_multimodal=[{is_multimodal}], google_llm=[{google_llm}],"
               f" user_file_bytes=[{user_file_bytes_log}],"
-              f" user_file_urls=[{user_file_urls}],"
-              f" user_file_type=[{user_file_types}].")
+              f" user_files=[{user_files}]")

   # TODO: Consider images in chat
   prompt_list = []
@@ -563,21 +556,20 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
     chat_model = GenerativeModel(google_llm)
     if is_multimodal:
       user_file_parts = []
-      if user_file_bytes is not None:
+      if user_file_bytes is not None and user_files is not None:
         # user_file_bytes refers to a single image and so we index into
-        # user_file_types (a list) to get a single mime type
+        # user_files (a list) to get a single mime type
         user_file_parts = [Part.from_data(user_file_bytes,
-                                          mime_type=user_file_types[0])]
-      elif user_file_urls is not None:
-        # user_file_urls and user_file_types are same-length lists
-        # referring to one or more images
+                                          mime_type=user_files[0].mime_type)]
+      elif user_files is not None:
+        # user_files is a list referring to one or more images
         user_file_parts = [
-          Part.from_uri(user_file_url, mime_type=user_file_type)
-          for user_file_url, user_file_type in zip(user_file_urls, user_file_types)
+          Part.from_uri(user_file.gcs_path, mime_type=user_file.mime_type)
+          for user_file in user_files
         ]
       else:
         raise RuntimeError(
-          "if is_multimodal one of user_file_bytes or user_file_urls must be set")
+          "if is_multi user_files must be set")
     context_list = [*user_file_parts, context_prompt]
     Logger.info(f"context list {context_list}")
     generation_config = GenerationConfig(**parameters)
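For clarity, here is the multimodal branch above pulled out as a standalone sketch: it shows how the Vertex Part list is now built from DataSourceFile entries rather than from parallel URL and mime-type lists. This restates the diffed logic under the same vertexai preview SDK the module already uses; the helper name itself is hypothetical, not a function in the repo.

from typing import List, Optional
from vertexai.preview.generative_models import Part
from services.query.data_source import DataSourceFile

def build_user_file_parts(user_file_bytes: Optional[bytes],
                          user_files: Optional[List[DataSourceFile]]) -> List[Part]:
  if user_file_bytes is not None and user_files is not None:
    # raw bytes describe a single file; its mime type comes from user_files[0]
    return [Part.from_data(user_file_bytes, mime_type=user_files[0].mime_type)]
  if user_files is not None:
    # GCS-hosted files are referenced by URI, one Part per file
    return [Part.from_uri(f.gcs_path, mime_type=f.mime_type) for f in user_files]
  raise RuntimeError("user_files must be set when is_multimodal is True")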

components/llm_service/src/services/llm_generate_test.py (+25 -5)
@@ -37,6 +37,7 @@
   clean_firestore)
 from common.utils.logging_handler import Logger
 from schemas.schema_examples import (CHAT_EXAMPLE, USER_EXAMPLE)
+from services.query.data_source import DataSourceFile

 Logger = Logger.get_logger(__file__)

@@ -77,7 +78,6 @@
 }

 FAKE_FILE_NAME = "test.png"
-FAKE_FILE_TYPE = "image/png"
 FAKE_PROMPT = "test prompt"


@@ -147,7 +147,7 @@ async def test_llm_generate_google(clean_firestore):


 @pytest.mark.asyncio
-async def test_llm_generate_multimodal(clean_firestore):
+async def test_llm_generate_multi_file(clean_firestore):
   get_model_config().llm_model_providers = {
     PROVIDER_VERTEX: TEST_VERTEX_CONFIG
   }
@@ -159,18 +159,38 @@ async def test_llm_generate_multimodal(clean_firestore):
   os.remove(FAKE_FILE_NAME)
   fake_upload_file = UploadFile(file=fake_file, filename=FAKE_FILE_NAME)
   fake_file_bytes = await fake_upload_file.read()
-
+  fake_file_data = [DataSourceFile(mime_type="image/png")]
   with mock.patch(
       "vertexai.preview.generative_models.GenerativeModel.generate_content_async",
       return_value=FAKE_GOOGLE_RESPONSE):
     response = await llm_generate_multimodal(FAKE_PROMPT,
                                              VERTEX_LLM_TYPE_GEMINI_PRO_VISION,
-                                             [FAKE_FILE_TYPE],
-                                             fake_file_bytes)
+                                             fake_file_bytes,
+                                             fake_file_data)
   fake_file.close()
   assert response == FAKE_GENERATE_RESPONSE


+@pytest.mark.asyncio
+async def test_llm_generate_multi_url(clean_firestore):
+  get_model_config().llm_model_providers = {
+    PROVIDER_VERTEX: TEST_VERTEX_CONFIG
+  }
+  get_model_config().llm_models = TEST_VERTEX_CONFIG
+
+  fake_file_data = [DataSourceFile(mime_type="image/png",
+                                   gcs_path="gs://fake_bucket/file.png")]
+  fake_file_bytes = None
+  with mock.patch(
+      "vertexai.preview.generative_models.GenerativeModel.generate_content_async",
+      return_value=FAKE_GOOGLE_RESPONSE):
+    response = await llm_generate_multimodal(FAKE_PROMPT,
+                                             VERTEX_LLM_TYPE_GEMINI_PRO_VISION,
+                                             fake_file_bytes,
+                                             fake_file_data)
+  assert response == FAKE_GENERATE_RESPONSE
+
+
 @pytest.mark.asyncio
 async def test_llm_chat_google(clean_firestore, test_chat):
   get_model_config().llm_model_providers = {

components/llm_service/src/services/query/data_source.py (+17 -1)
@@ -54,12 +54,28 @@ def __init__(self,
                src_url:str=None,
                local_path:str=None,
                gcs_path:str=None,
-               doc_id:str=None):
+               doc_id:str=None,
+               mime_type:str=None):
     self.doc_name = doc_name
     self.src_url = src_url
     self.local_path = local_path
     self.gcs_path = gcs_path
     self.doc_id = doc_id
+    self.mime_type = mime_type
+
+  def __repr__(self) -> str:
+    """
+    Log-friendly string representation of a DataSourceFile
+    """
+    return (
+      f"DataSourceFile(doc_name={self.doc_name}, "
+      f"src_url={self.src_url}, "
+      f"local_path={self.local_path}, "
+      f"gcs_path={self.gcs_path}, "
+      f"doc_id={self.doc_id}, "
+      f"mime_type={self.mime_type})"
+    )
+

 class DataSource:
   """

components/llm_service/src/services/query/query_service.py (+7 -9)
@@ -49,7 +49,7 @@
   MatchingEngineVectorStore,
   PostgresVectorStore,
   NUM_MATCH_RESULTS)
-from services.query.data_source import DataSource
+from services.query.data_source import DataSource, DataSourceFile
 from services.query.web_datasource import WebDataSource
 from services.query.sharepoint_datasource import SharePointDataSource
 from services.query.vertex_search import (build_vertex_search,
@@ -186,23 +186,21 @@ async def query_generate(

   # generate list of URLs for additional context
   # (from non-text info in query_references)
-  context_urls = []
-  context_urls_mimetype = []
+  context_files = []
   for ref in query_references:
     if hasattr(ref, "modality") and ref.modality != "text":
       if hasattr(ref, "chunk_url"):
         ref_filename = ref.chunk_url
         ref_mimetype = validate_multimodal_file_type(file_name=ref_filename,
                                                      file_b64=None)
-        context_urls.append(ref_filename)
-        context_urls_mimetype.append(ref_mimetype)
-        # TODO: If ref is a video chunk, then update ref.chunk_url
-        # according to ref.timestamp_start and ref.timestamp_stop
+        context_files.append(DataSourceFile(gcs_path=ref_filename,
+                                            mime_type=ref_mimetype))
+        # TODO: If ref is a video chunk, then update new element of
+        # context_files according to ref.timestamp_start and ref.timestamp_stop

   # send prompt and additional context to model
   question_response = await llm_chat(question_prompt, llm_type,
-                                     chat_file_types=context_urls_mimetype,
-                                     chat_file_urls=context_urls)
+                                     chat_files=context_files)

   # update user query with response
   if user_query:
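To make the data-shape change concrete: for a single non-text reference whose chunk_url points at an image in GCS, the loop above now yields one DataSourceFile rather than two parallel list entries. The bucket path and mime type below are illustrative only.

context_files = [
    DataSourceFile(gcs_path="gs://downloads-bucket/chunk_0.png",
                   mime_type="image/png")
]
# handed to the chat model as:
#   llm_chat(question_prompt, llm_type, chat_files=context_files)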

components/llm_service/src/services/query/web_datasource.py (+9 -4)
@@ -25,7 +25,7 @@
 import sys
 import tempfile
 from pathlib import Path
-from typing import List
+from typing import List, Union
 from scrapy import signals
 from scrapy.crawler import CrawlerProcess
 from scrapy.linkextractors import LinkExtractor
@@ -42,14 +42,19 @@

 Logger = Logger.get_logger(__file__)

-def save_content(filepath: str, file_name: str, content: str) -> None:
+def save_content(filepath: str, file_name: str,
+                 content: Union[str, bytes]) -> None:
   """
   Save content in a file in a local directory
   """
   Logger.info(f"Saving {file_name} to {filepath}")
   doc_filepath = os.path.join(filepath, file_name)
-  with open(doc_filepath, "w", encoding="utf-8") as f:
-    f.write(content)
+  if isinstance(content, bytes):
+    with open(doc_filepath, "wb") as f:
+      f.write(content)
+  else:
+    with open(doc_filepath, "w", encoding="utf-8") as f:
+      f.write(content)
   Logger.info(f"{len(content)} bytes written")
   return doc_filepath
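With the Union[str, bytes] signature, save_content now picks text or binary mode from the runtime type of content. A usage sketch with illustrative paths and payloads; note the function returns the written path even though it is still annotated -> None.

html_path = save_content("/tmp/scrape", "index.html", "<html>...</html>")
image_path = save_content("/tmp/scrape", "logo.png", b"\x89PNG\r\n\x1a\n")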
