Skip to content

Commit ca89229

Browse files
authored
Merge pull request #56 from smcazares/672-text-image-pair
672 text image pair
2 parents b21c7c4 + a5db04a commit ca89229

File tree

3 files changed

+55
-23
lines changed

3 files changed

+55
-23
lines changed

components/common/src/common/models/llm_query.py

+26-18
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,17 @@ class QueryReference(BaseModel):
201201
document_id = TextField(required=True) # All modalities
202202
document_url = TextField(required=True) # All modalities
203203
modality = TextField(required=True) # All modalities: text image video audio
204-
chunk_id = TextField(required=False) # All modalities
205-
chunk_url = TextField(required=False) # Image or video or audio only
206-
page = NumberField(required=False) # Text or image only
207-
document_text = TextField(required=False) # Text only
208-
timestamp_start = NumberField(required=False) # Video or audio only
209-
timestamp_stop = NumberField(required=False) # Video or audio only
204+
chunk_id = TextField(required=True) # All modalities
205+
chunk_url = TextField(
206+
required=False, default=None) # Image or video or audio only
207+
page = NumberField(required=False, default=None) # Text or image only
208+
document_text = TextField(required=False, default=None) # Text only
209+
timestamp_start = NumberField(
210+
required=False, default=None) # Video or audio only
211+
timestamp_stop = NumberField(
212+
required=False, default=None) # Video or audio only
213+
linked_ids = ListField(
214+
IDField(), required=False, default=None) # All modalities
210215

211216
def __repr__(self) -> str:
212217
"""
@@ -218,24 +223,23 @@ def __repr__(self) -> str:
218223
document_text_snippet = self.document_text[:min(100,
219224
document_text_num_chars)]
220225
chunk_url = None
221-
page = None
222226
else:
223227
document_text_num_tokens = None
224228
document_text_num_chars = None
225229
document_text_snippet = None
226230
chunk_url = self.chunk_url
227-
page = self.page
228231
return (
229232
f"Query_Ref(query_engine_name={self.query_engine}, "
230233
f"document_id={self.document_id}, "
231234
f"document_url={self.document_url}, "
232235
f"chunk_id={self.chunk_id}, "
233236
f"chunk_url={chunk_url}, "
234237
f"modality={self.modality}, "
235-
f"page={page}, "
238+
f"page={self.page}, "
236239
f"chunk_num_tokens={document_text_num_tokens}, "
237240
f"chunk_num_chars={document_text_num_chars}, "
238-
f"chunk_text={document_text_snippet})"
241+
f"chunk_text={document_text_snippet}, "
242+
f"linked_ids={self.linked_ids})"
239243
)
240244

241245
class Meta:
@@ -360,14 +364,18 @@ class QueryDocumentChunk(BaseModel):
360364
query_document_id = TextField(required=True) # All modalities
361365
index = NumberField(required=True) # All modalities
362366
modality = TextField(required=True) # All modalities: text image video audio
363-
page = NumberField(required=False) # Text or image only
364-
chunk_url = TextField(required=False) # Image or video or audio only
365-
text = TextField(required=False) # Text only
366-
clean_text = TextField(required=False) # Text only (optional)
367-
sentences = ListField(required=False) # Text only (optional)
368-
timestamp_start = NumberField(required=False) # Video or audio only
369-
timestamp_stop = NumberField(required=False) # Video or audio only
370-
linked_ids = ListField(required=False) # All modalities
367+
page = NumberField(required=False, default=None) # Text or image only
368+
chunk_url = TextField(
369+
required=False, default=None) # Image or video or audio only
370+
text = TextField(required=False, default=None) # Text only
371+
clean_text = TextField(required=False, default=None) # Text only (optional)
372+
sentences = ListField(required=False, default=None) # Text only (optional)
373+
timestamp_start = NumberField(
374+
required=False, default=None) # Video or audio only
375+
timestamp_stop = NumberField(
376+
required=False, default=None) # Video or audio only
377+
linked_ids = ListField(
378+
IDField(), required=False, default=None) # All modalities
371379

372380
class Meta:
373381
ignore_none_field = False

components/frontend_streamlit/src/pages/4_Query.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -95,20 +95,26 @@ def chat_content():
9595
_, chunk_type = splitext(chunk_url)
9696
chunk_url = chunk_url.replace("gs://",
9797
"https://storage.googleapis.com/", 1)
98-
9998
document_url = reference["document_url"]
99+
page = reference["page"]
100+
if page:
101+
# References from multimodal query engines have page numbers
102+
reference_header = (f"\nReference {query_index}:"
103+
f" {document_url}, Page {page+1}")
104+
else:
105+
# References from text-only query engines do not have page numbers
106+
reference_header = f"\nReference {query_index}: {document_url}"
100107
if modality == "text":
101108
document_text = reference["document_text"]
102109
st.text_area(
103-
f"\nReference {query_index}: {document_url}",
110+
reference_header,
104111
document_text,
105112
key=f"ref_{query_index}")
106113
elif modality == "image" and chunk_type in [".pdf",
107114
".png", ".jpg", ".jpeg", ".gif", ".bmp"]:
108115
# .tif/.tiff not available, all other file types are untested
109-
page = reference["page"]
110116
st.write(
111-
f"\nReference {query_index}: {document_url}, Page {page+1}",
117+
reference_header,
112118
key=f"ref_{query_index}")
113119
st.image(chunk_url)
114120
else:

components/llm_service/src/services/query/query_service.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,20 @@ async def query_search(q_engine: QueryEngine,
422422
query_reference.save()
423423
query_references.append(query_reference)
424424

425+
# Also create a query_reference for other modalities of the same chunk
426+
linked_ids = query_reference.linked_ids
427+
if linked_ids:
428+
for linked_id in linked_ids:
429+
query_doc_chunk_friend = QueryDocumentChunk.find_by_id(linked_id)
430+
query_reference_friend = make_query_reference(
431+
q_engine=q_engine,
432+
query_doc=query_doc,
433+
doc_chunk=query_doc_chunk_friend,
434+
query_embeddings=query_embeddings,
435+
rank_sentences=rank_sentences)
436+
query_reference_friend.save()
437+
query_references.append(query_reference_friend)
438+
425439
Logger.info(f"Retrieved {len(query_references)} "
426440
f"references={query_references}")
427441

@@ -504,6 +518,7 @@ def make_query_reference(q_engine: QueryEngine,
504518
query_reference_dict["document_url"]=query_doc.doc_url
505519
query_reference_dict["modality"]=modality
506520
query_reference_dict["chunk_id"]=doc_chunk.id
521+
query_reference_dict["linked_ids"]=doc_chunk.linked_ids
507522
# For text chunk only
508523
if modality=="text":
509524
query_reference_dict["page"]=doc_chunk.page
@@ -1071,7 +1086,10 @@ async def process_documents(doc_url: str, qe_vector_store: VectorStore,
10711086
for index in linked_indexes:
10721087
query_doc_chunk = \
10731088
QueryDocumentChunk.find_by_index(q_engine.id, index)
1074-
query_doc_chunk.linked_ids = linked_ids
1089+
friend_ids = [friend_id for friend_id in linked_ids
1090+
if friend_id != query_doc_chunk.id]
1091+
query_doc_chunk.linked_ids = friend_ids
1092+
query_doc_chunk.save()
10751093

10761094
else:
10771095
# Use text-only pipeline

0 commit comments

Comments
 (0)