
Commit c82f115

Merge pull request #48 from smcazares/494-prompt-build
494 prompt build
2 parents 62b5eaa + d7d32c6 commit c82f115

7 files changed: +65 -25 lines changed

components/frontend_streamlit/src/pages/4_Query.py

Lines changed: 6 additions & 2 deletions
@@ -96,16 +96,20 @@ def chat_content():
       chunk_url = chunk_url.replace("gs://",
                                     "https://storage.googleapis.com/", 1)

+      document_url = reference["document_url"]
       if modality == "text":
-        document_url = reference["document_url"]
         document_text = reference["document_text"]
         st.text_area(
-          f"Reference: {document_url}",
+          f"\nReference {query_index}: {document_url}",
           document_text,
           key=f"ref_{query_index}")
       elif modality == "image" and chunk_type in [".pdf",
           ".png", ".jpg", ".jpeg", ".gif", ".bmp"]:
         # .tif/.tiff not available, all other file types are untested
+        page = reference["page"]
+        st.write(
+          f"\nReference {query_index}: {document_url}, Page {page+1}",
+          key=f"ref_{query_index}")
         st.image(chunk_url)
       else:
         logging.error("Reference modality unknown")

components/llm_service/src/routes/chat.py

Lines changed: 1 addition & 1 deletion
@@ -293,7 +293,7 @@ async def create_user_chat(
     # generate text from prompt
     response = await llm_chat(prompt,
                               llm_type,
-                              chat_file_type=chat_file_type,
+                              chat_file_types=[chat_file_type],
                               chat_file_bytes=chat_file_bytes,
                               chat_file_urls=chat_file_urls)

components/llm_service/src/routes/llm.py

Lines changed: 1 addition & 1 deletion
@@ -279,7 +279,7 @@ async def generate_multimodal(gen_config: LLMMultimodalGenerateModel):

   try:
     user_file_bytes = b64decode(user_file_b64)
-    result = await llm_generate_multimodal(prompt, file_mime_type, llm_type,
+    result = await llm_generate_multimodal(prompt, [file_mime_type], llm_type,
                                            user_file_bytes)

     return {

components/llm_service/src/services/llm_generate.py

Lines changed: 29 additions & 16 deletions
@@ -107,7 +107,7 @@ async def llm_generate(prompt: str, llm_type: str) -> str:
   except Exception as e:
     raise InternalServerError(str(e)) from e

-async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_type: str,
+async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_types: List[str],
                                   user_file_bytes: bytes = None,
                                   user_file_urls: List[str] = None) -> str:
   """
@@ -116,6 +116,7 @@ async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_type: st
     prompt: the text prompt to pass to the LLM
     user_file_bytes: bytes of the file provided by the user
     user_file_urls: list of URLs to include in context
+    user_file_types: list of mime times for files to include in context
     llm_type: the type of LLM to use (default to gemini)
   Returns:
     the text response: str
@@ -145,7 +146,7 @@ async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_type: st
           f"Vertex model {llm_type} needs to be multimodal")
     response = await google_llm_predict(prompt, is_chat, is_multimodal,
                                         google_llm, None, user_file_bytes,
-                                        user_file_urls, user_file_type)
+                                        user_file_urls, user_file_types)
   else:
     raise ResourceNotFoundException(f"Cannot find llm type '{llm_type}'")

@@ -159,9 +160,9 @@ async def llm_generate_multimodal(prompt: str, llm_type: str, user_file_type: st
 async def llm_chat(prompt: str, llm_type: str,
                    user_chat: Optional[UserChat] = None,
                    user_query: Optional[UserQuery] = None,
-                   chat_file_type: str = None,
-                   chat_file_urls: List[str] = None,
-                   chat_file_bytes: bytes = None) -> str:
+                   chat_file_types: Optional[List[str]] = None,
+                   chat_file_urls: Optional[List[str]] = None,
+                   chat_file_bytes: Optional[bytes] = None) -> str:
   """
   Send a prompt to a chat model and return string response.
   Supports including a file in the chat context, either by URL or
@@ -174,7 +175,7 @@ async def llm_chat(prompt: str, llm_type: str,
     user_query (optional): a user query to use for context
     chat_file_bytes (bytes): bytes of file to include in chat context
     chat_file_urls (List[str]): urls of files to include in chat context
-    chat_file_type (str): mime type of file to include in chat context
+    chat_file_types (List[str]): mime types of files to include in chat context
   Returns:
     the text response: str
   """
@@ -185,7 +186,7 @@ async def llm_chat(prompt: str, llm_type: str,
               f" user_query=[{user_query}]"
               f" chat_file_bytes=[{chat_file_bytes_log}]"
               f" chat_file_urls=[{chat_file_urls}]"
-              f" chat_file_type=[{chat_file_type}]")
+              f" chat_file_type=[{chat_file_types}]")

   if llm_type not in get_model_config().get_chat_llm_types():
     raise ResourceNotFoundException(f"Cannot find chat llm type '{llm_type}'")
@@ -198,7 +199,7 @@ async def llm_chat(prompt: str, llm_type: str,
           "Must set only one of chat_file_bytes/chat_file_urls")
     if llm_type not in get_provider_models(PROVIDER_VERTEX):
       raise InternalServerError("Chat files only supported for Vertex")
-    if chat_file_type is None:
+    if chat_file_types is None:
       raise InternalServerError("Mime type must be passed for chat file")
     is_multimodal = True

@@ -209,6 +210,8 @@ async def llm_chat(prompt: str, llm_type: str,
   if user_chat is not None or user_query is not None:
     context_prompt = get_context_prompt(
         user_chat=user_chat, user_query=user_query)
+    # context_prompt includes only text (no images/video) from
+    # user_chat.history and user_query.history
     prompt = context_prompt + "\n" + prompt

   # check whether the context length exceeds the limit for the model
@@ -241,7 +244,7 @@ async def llm_chat(prompt: str, llm_type: str,
       response = await google_llm_predict(prompt, is_chat, is_multimodal,
                                           google_llm, user_chat,
                                           chat_file_bytes,
-                                          chat_file_urls, chat_file_type)
+                                          chat_file_urls, chat_file_types)
     elif llm_type in get_provider_models(PROVIDER_LANGCHAIN):
       response = await langchain_llm_generate(prompt, llm_type, user_chat)
   return response
@@ -271,6 +274,7 @@ def get_context_prompt(user_chat=None,
         prompt_list.append(f"Human input: {content}")
       elif UserChat.is_ai(entry):
         prompt_list.append(f"AI response: {content}")
+    # prompt_list includes only text from user_chat.history

   if user_query is not None:
     history = user_query.history
@@ -280,6 +284,7 @@ def get_context_prompt(user_chat=None,
         prompt_list.append(f"Human input: {content}")
       elif UserQuery.is_ai(entry):
         prompt_list.append(f"AI response: {content}")
+    # prompt_list includes only text from user_query.history

   context_prompt = "\n\n".join(prompt_list)

@@ -294,6 +299,8 @@ def check_context_length(prompt, llm_type):
   """
   # check if prompt exceeds context window length for model
   # assume a constant relationship between tokens and chars
+  # TODO: Recalculate max_context_length for text prompt,
+  # subtracting out tokens used by non-text context (image, video, etc)
   token_length = len(prompt) / CHARS_PER_TOKEN
   max_context_length = get_model_config_value(llm_type,
                                               KEY_MODEL_CONTEXT_LENGTH,
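For reference, a standalone sketch of the length check this TODO comments on, with assumed values for the chars-per-token ratio and the context window (in the service these come from CHARS_PER_TOKEN and get_model_config_value):

# Sketch of the chars-per-token context check; the constants are assumed here.
CHARS_PER_TOKEN = 4          # assumed: roughly 4 characters per token
MAX_CONTEXT_LENGTH = 32000   # assumed context window, in tokens

def check_context_length_sketch(prompt: str) -> None:
  """Raise if the estimated token count exceeds the model's context window."""
  token_length = len(prompt) / CHARS_PER_TOKEN
  if token_length > MAX_CONTEXT_LENGTH:
    raise ValueError(f"Prompt exceeds context window: ~{token_length:.0f} "
                     f"tokens vs. limit {MAX_CONTEXT_LENGTH}")
  # Per the TODO above: for multimodal prompts the text budget should also
  # subtract the tokens consumed by non-text context (images, video, etc).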
@@ -489,9 +496,9 @@ async def model_garden_predict(prompt: str,

 async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
                              google_llm: str, user_chat=None,
-                             user_file_bytes: bytes=None,
-                             user_file_urls: List[str]=None,
-                             user_file_type: str=None) -> str:
+                             user_file_bytes: Optional[bytes]=None,
+                             user_file_urls: Optional[List[str]]=None,
+                             user_file_types: Optional[List[str]]=None) -> str:
   """
   Generate text with a Google multimodal LLM given a prompt.
   Args:
@@ -502,7 +509,7 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
     user_chat: chat history
     user_file_bytes: the bytes of the file provided by the user
     user_file_urls: list of urls of files provided by the user
-    user_file_type: mime type of the file provided by the user
+    user_file_types: list of mime types of the files provided by the user
   Returns:
     the text response.
   """
@@ -513,7 +520,7 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
              f" is_multimodal=[{is_multimodal}], google_llm=[{google_llm}],"
              f" user_file_bytes=[{user_file_bytes_log}],"
              f" user_file_urls=[{user_file_urls}],"
-             f" user_file_type=[{user_file_type}].")
+             f" user_file_type=[{user_file_types}].")

   # TODO: Consider images in chat
   prompt_list = []
@@ -525,6 +532,8 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
         prompt_list.append(f"Human input: {content}")
       elif UserChat.is_ai(entry):
         prompt_list.append(f"AI response: {content}")
+    # prompt_list includes only text (no images/video)
+    # from user_chat.history
   prompt_list.append(prompt)
   context_prompt = "\n\n".join(prompt_list)

@@ -555,12 +564,16 @@ async def google_llm_predict(prompt: str, is_chat: bool, is_multimodal: bool,
   if is_multimodal:
     user_file_parts = []
     if user_file_bytes is not None:
+      # user_file_bytes refers to a single image and so we index into
+      # user_file_types (a list) to get a single mime type
      user_file_parts = [Part.from_data(user_file_bytes,
-                                       mime_type=user_file_type)]
+                                       mime_type=user_file_types[0])]
     elif user_file_urls is not None:
+      # user_file_urls and user_file_types are same-length lists
+      # referring to one or more images
       user_file_parts = [
         Part.from_uri(user_file_url, mime_type=user_file_type)
-        for user_file_url in user_file_urls
+        for user_file_url, user_file_type in zip(user_file_urls, user_file_types)
       ]
     else:
       raise RuntimeError(
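Outside the service, the same list-based Part construction looks roughly like this with the Vertex AI SDK; the project, model name, URLs, and mime types below are placeholder examples, not values from the repo.

# Standalone sketch of the Part-building change above, using the Vertex AI SDK.
import vertexai
from vertexai.generative_models import GenerativeModel, Part

vertexai.init(project="my-project", location="us-central1")  # assumed project/region

user_file_urls = ["gs://my-bucket/page_1.png", "gs://my-bucket/page_2.jpg"]  # assumed
user_file_types = ["image/png", "image/jpeg"]                                # assumed

# One Part per (url, mime type) pair; the two lists are expected to be the
# same length, mirroring user_file_urls / user_file_types in google_llm_predict.
user_file_parts = [
  Part.from_uri(user_file_url, mime_type=user_file_type)
  for user_file_url, user_file_type in zip(user_file_urls, user_file_types)
]

model = GenerativeModel("gemini-1.0-pro-vision")  # assumed multimodal model name
response = model.generate_content(["Describe these pages.", *user_file_parts])
print(response.text)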

components/llm_service/src/services/llm_generate_test.py

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ async def test_llm_generate_multimodal(clean_firestore):
                     return_value=FAKE_GOOGLE_RESPONSE):
     response = await llm_generate_multimodal(FAKE_PROMPT,
                                              VERTEX_LLM_TYPE_GEMINI_PRO_VISION,
-                                             FAKE_FILE_TYPE,
+                                             [FAKE_FILE_TYPE],
                                              fake_file_bytes)
   fake_file.close()
   assert response == FAKE_GENERATE_RESPONSE

components/llm_service/src/services/query/query_prompts.py

Lines changed: 5 additions & 1 deletion
@@ -31,7 +31,11 @@ def get_question_prompt(prompt: str,
   """ Create question prompt with context for LLM """
   Logger.info(f"Creating question prompt with context "
               f"for LLM prompt=[{prompt}]")
-  context_list = [ref.document_text for ref in query_context]
+  context_list = []
+  for ref in query_context:
+    if hasattr(ref, "modality") and ref.modality=="text":
+      if hasattr(ref, "document_text"):
+        context_list.append(ref.document_text)
   text_context = "\n\n".join(context_list)

   if llm_type == TRUSS_LLM_LLAMA2_CHAT:
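A small worked example of this filter's effect, using a hypothetical stand-in for the real reference model (only the attribute names modality and document_text come from the diff):

# Sketch: only text-modality references contribute to the question prompt.
from dataclasses import dataclass

@dataclass
class FakeQueryRef:  # hypothetical stand-in for the real reference model
  modality: str
  document_text: str = ""
  chunk_url: str = ""

query_context = [
  FakeQueryRef(modality="text", document_text="Paragraph 1 of the report..."),
  FakeQueryRef(modality="image", chunk_url="gs://example-bucket/figure_3.png"),
  FakeQueryRef(modality="text", document_text="Paragraph 7 of the report..."),
]

context_list = []
for ref in query_context:
  if hasattr(ref, "modality") and ref.modality == "text":
    if hasattr(ref, "document_text"):
      context_list.append(ref.document_text)

text_context = "\n\n".join(context_list)
# The image chunk is dropped here; query_service.py forwards it to the LLM
# separately as a chat file URL with its mime type.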

components/llm_service/src/services/query/query_service.py

Lines changed: 22 additions & 3 deletions
@@ -57,6 +57,7 @@
                                          delete_vertex_search)
 from utils.errors import (NoDocumentsIndexedException,
                           ContextWindowExceededException)
+from utils.file_helper import validate_multimodal_file_type
 from utils import text_helper
 from config import (PROJECT_ID, DEFAULT_QUERY_CHAT_MODEL,
                     DEFAULT_MULTIMODAL_LLM_TYPE,
@@ -176,14 +177,32 @@ async def query_generate(
         prompt, None, user_id, q_engine, query_references, user_query)

   # generate question prompt
+  # (from user's text prompt plus text info in query_references)
   question_prompt, query_references = \
       await generate_question_prompt(prompt,
                                      llm_type,
                                      query_references,
                                      user_query)

-  # send prompt to model
-  question_response = await llm_chat(question_prompt, llm_type)
+  # generate list of URLs for additional context
+  # (from non-text info in query_references)
+  context_urls = []
+  context_urls_mimetype = []
+  for ref in query_references:
+    if hasattr(ref, "modality") and ref.modality != "text":
+      if hasattr(ref, "chunk_url"):
+        ref_filename = ref.chunk_url
+        ref_mimetype = validate_multimodal_file_type(file_name=ref_filename,
+                                                     file_b64=None)
+        context_urls.append(ref_filename)
+        context_urls_mimetype.append(ref_mimetype)
+        # TODO: If ref is a video chunk, then update ref.chunk_url
+        # according to ref.timestamp_start and ref.timestamp_stop
+
+  # send prompt and additional context to model
+  question_response = await llm_chat(question_prompt, llm_type,
+                                     chat_file_types=context_urls_mimetype,
+                                     chat_file_urls=context_urls)

   # update user query with response
   if user_query:
@@ -650,7 +669,7 @@ def update_user_query(prompt: str,
                       query_references: List[QueryReference],
                       user_query: UserQuery = None,
                       query_filter=None) -> \
-                      Tuple[UserQuery, dict]:
+                      Tuple[UserQuery, List[dict]]:
   """ Save user query history """
   query_reference_dicts = [
     ref.get_fields(reformat_datetime=True) for ref in query_references
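Taken together, the query_generate change splits references by modality: text is folded into the question prompt, and everything else is forwarded to llm_chat as parallel lists of file URLs and mime types. A condensed, self-contained sketch of that split follows; the mimetypes-based helper is only a stand-in for utils.file_helper.validate_multimodal_file_type, and the import path is assumed.

# Condensed sketch of the new query_generate flow (assumptions noted inline).
import mimetypes
from services.llm_generate import llm_chat  # service import path (assumed)

def guess_mime_type(file_name: str) -> str:
  """Stand-in for utils.file_helper.validate_multimodal_file_type."""
  mime, _ = mimetypes.guess_type(file_name)
  return mime or "application/octet-stream"

async def answer_query(question_prompt: str, query_references: list, llm_type: str) -> str:
  # Non-text references become additional multimodal context for the chat call;
  # text references are already folded into question_prompt upstream.
  context_urls, context_urls_mimetype = [], []
  for ref in query_references:
    if getattr(ref, "modality", "text") != "text" and getattr(ref, "chunk_url", None):
      context_urls.append(ref.chunk_url)
      context_urls_mimetype.append(guess_mime_type(ref.chunk_url))

  return await llm_chat(question_prompt, llm_type,
                        chat_file_types=context_urls_mimetype,
                        chat_file_urls=context_urls)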
