Skip to content

Commit 6e1dc06

Browse files
speedstorm1 authored and copybara-github committed
feat: Add advanced PDF parsing option for RAG file import
PiperOrigin-RevId: 663391146
1 parent d03468a commit 6e1dc06

File tree

4 files changed

+50
-1
lines changed

4 files changed

+50
-1
lines changed

tests/unit/vertex_rag/test_rag_constants.py

+21 -1
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
from google.cloud.aiplatform_v1beta1 import (
2323
GoogleDriveSource,
2424
RagFileChunkingConfig,
25+
RagFileParsingConfig,
2526
ImportRagFilesConfig,
2627
ImportRagFilesRequest,
2728
ImportRagFilesResponse,
@@ -93,6 +94,7 @@
9394
# GCS
9495
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig()
9596
TEST_IMPORT_FILES_CONFIG_GCS.gcs_source.uris = [TEST_GCS_PATH]
97+
TEST_IMPORT_FILES_CONFIG_GCS.rag_file_parsing_config.use_advanced_pdf_parsing = False
9698
TEST_IMPORT_REQUEST_GCS = ImportRagFilesRequest(
9799
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
98100
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_GCS,
@@ -112,18 +114,36 @@
112114
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
113115
)
114116
]
117+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.rag_file_parsing_config.use_advanced_pdf_parsing = (
118+
False
119+
)
120+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig()
121+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.google_drive_source.resource_ids = [
122+
GoogleDriveSource.ResourceId(
123+
resource_id=TEST_DRIVE_FOLDER_ID,
124+
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
125+
)
126+
]
127+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.rag_file_parsing_config.use_advanced_pdf_parsing = (
128+
True
129+
)
115130
TEST_IMPORT_REQUEST_DRIVE_FOLDER = ImportRagFilesRequest(
116131
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
117132
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER,
118133
)
134+
TEST_IMPORT_REQUEST_DRIVE_FOLDER_PARSING = ImportRagFilesRequest(
135+
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
136+
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING,
137+
)
119138
# Google Drive files
120139
TEST_DRIVE_FILE_ID = "456"
121140
TEST_DRIVE_FILE = f"https://drive.google.com/file/d/{TEST_DRIVE_FILE_ID}"
122141
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE = ImportRagFilesConfig(
123142
rag_file_chunking_config=RagFileChunkingConfig(
124143
chunk_size=TEST_CHUNK_SIZE,
125144
chunk_overlap=TEST_CHUNK_OVERLAP,
126-
)
145+
),
146+
rag_file_parsing_config=RagFileParsingConfig(use_advanced_pdf_parsing=False),
127147
)
128148
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.max_embedding_requests_per_min = 800
129149

tests/unit/vertex_rag/test_rag_data.py

+15
Original file line number | Diff line number | Diff line change
@@ -166,6 +166,10 @@ def import_files_request_eq(returned_request, expected_request):
166166
returned_request.import_rag_files_config.jira_source.jira_queries
167167
== expected_request.import_rag_files_config.jira_source.jira_queries
168168
)
169+
assert (
170+
returned_request.import_rag_files_config.rag_file_parsing_config
171+
== expected_request.import_rag_files_config.rag_file_parsing_config
172+
)
169173

170174

171175
@pytest.mark.usefixtures("google_auth_mock")
@@ -396,6 +400,17 @@ def test_prepare_import_files_request_drive_folders(self, path):
396400
)
397401
import_files_request_eq(request, tc.TEST_IMPORT_REQUEST_DRIVE_FOLDER)
398402

403+
@pytest.mark.parametrize("path", [tc.TEST_DRIVE_FOLDER, tc.TEST_DRIVE_FOLDER_2])
404+
def test_prepare_import_files_request_drive_folders_with_pdf_parsing(self, path):
405+
request = prepare_import_files_request(
406+
corpus_name=tc.TEST_RAG_CORPUS_RESOURCE_NAME,
407+
paths=[path],
408+
chunk_size=tc.TEST_CHUNK_SIZE,
409+
chunk_overlap=tc.TEST_CHUNK_OVERLAP,
410+
use_advanced_pdf_parsing=True,
411+
)
412+
import_files_request_eq(request, tc.TEST_IMPORT_REQUEST_DRIVE_FOLDER_PARSING)
413+
399414
def test_prepare_import_files_request_drive_files(self):
400415
paths = [tc.TEST_DRIVE_FILE]
401416
request = prepare_import_files_request(

vertexai/preview/rag/rag_data.py

+8
Original file line number | Diff line number | Diff line change
@@ -281,6 +281,7 @@ def import_files(
281281
chunk_overlap: int = 200,
282282
timeout: int = 600,
283283
max_embedding_requests_per_min: int = 1000,
284+
use_advanced_pdf_parsing: Optional[bool] = False,
284285
) -> ImportRagFilesResponse:
285286
"""
286287
Import files to an existing RagCorpus, wait until completion.
@@ -364,6 +365,8 @@ def import_files(
364365
here. If unspecified, a default value of 1,000
365366
QPM would be used.
366367
timeout: Default is 600 seconds.
368+
use_advanced_pdf_parsing: Whether to use advanced PDF
369+
parsing on uploaded files.
367370
Returns:
368371
ImportRagFilesResponse.
369372
"""
@@ -379,6 +382,7 @@ def import_files(
379382
chunk_size=chunk_size,
380383
chunk_overlap=chunk_overlap,
381384
max_embedding_requests_per_min=max_embedding_requests_per_min,
385+
use_advanced_pdf_parsing=use_advanced_pdf_parsing,
382386
)
383387
client = _gapic_utils.create_rag_data_service_client()
384388
try:
@@ -396,6 +400,7 @@ async def import_files_async(
396400
chunk_size: int = 1024,
397401
chunk_overlap: int = 200,
398402
max_embedding_requests_per_min: int = 1000,
403+
use_advanced_pdf_parsing: Optional[bool] = False,
399404
) -> operation_async.AsyncOperation:
400405
"""
401406
Import files to an existing RagCorpus asynchronously.
@@ -479,6 +484,8 @@ async def import_files_async(
479484
page on the project to set an appropriate value
480485
here. If unspecified, a default value of 1,000
481486
QPM would be used.
487+
use_advanced_pdf_parsing: Whether to use advanced PDF
488+
parsing on uploaded files.
482489
Returns:
483490
operation_async.AsyncOperation.
484491
"""
@@ -494,6 +501,7 @@ async def import_files_async(
494501
chunk_size=chunk_size,
495502
chunk_overlap=chunk_overlap,
496503
max_embedding_requests_per_min=max_embedding_requests_per_min,
504+
use_advanced_pdf_parsing=use_advanced_pdf_parsing,
497505
)
498506
async_client = _gapic_utils.create_rag_data_service_async_client()
499507
try:

vertexai/preview/rag/utils/_gapic_utils.py

+6
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
ImportRagFilesConfig,
2424
ImportRagFilesRequest,
2525
RagFileChunkingConfig,
26+
RagFileParsingConfig,
2627
RagCorpus as GapicRagCorpus,
2728
RagFile as GapicRagFile,
2829
SlackSource as GapicSlackSource,
@@ -217,19 +218,24 @@ def prepare_import_files_request(
217218
chunk_size: int = 1024,
218219
chunk_overlap: int = 200,
219220
max_embedding_requests_per_min: int = 1000,
221+
use_advanced_pdf_parsing: bool = False,
220222
) -> ImportRagFilesRequest:
221223
if len(corpus_name.split("/")) != 6:
222224
raise ValueError(
223225
"corpus_name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`"
224226
)
225227

228+
rag_file_parsing_config = RagFileParsingConfig(
229+
use_advanced_pdf_parsing=use_advanced_pdf_parsing,
230+
)
226231
rag_file_chunking_config = RagFileChunkingConfig(
227232
chunk_size=chunk_size,
228233
chunk_overlap=chunk_overlap,
229234
)
230235
import_rag_files_config = ImportRagFilesConfig(
231236
rag_file_chunking_config=rag_file_chunking_config,
232237
max_embedding_requests_per_min=max_embedding_requests_per_min,
238+
rag_file_parsing_config=rag_file_parsing_config,
233239
)
234240

235241
if source is not None:

0 commit comments

Comments
 (0)