Skip to content

Commit ffe3230

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Update v1beta1 sdk for RagFileTransformationConfig and Ranking protos
PiperOrigin-RevId: 700877396
1 parent 6faa1d0 commit ffe3230

File tree

8 files changed

+250
-83
lines changed

8 files changed

+250
-83
lines changed

tests/unit/vertex_rag/test_rag_constants_preview.py

+49-27
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
EmbeddingModelConfig,
2323
Filter,
2424
HybridSearch,
25+
LlmRanker,
2526
Pinecone,
2627
RagCorpus,
2728
RagFile,
2829
RagResource,
2930
RagRetrievalConfig,
3031
Ranking,
3132
RankService,
32-
LlmRanker,
3333
SharePointSource,
3434
SharePointSources,
3535
SlackChannelsSource,
@@ -44,6 +44,7 @@
4444
from google.cloud.aiplatform_v1beta1 import (
4545
GoogleDriveSource,
4646
RagFileChunkingConfig,
47+
RagFileTransformationConfig,
4748
RagFileParsingConfig,
4849
ImportRagFilesConfig,
4950
ImportRagFilesRequest,
@@ -256,10 +257,22 @@
256257
TEST_RAG_FILE_JSON_ERROR = {"error": {"code": 13}}
257258
TEST_CHUNK_SIZE = 512
258259
TEST_CHUNK_OVERLAP = 100
260+
TEST_RAG_FILE_TRANSFORMATION_CONFIG = RagFileTransformationConfig(
261+
rag_file_chunking_config=RagFileChunkingConfig(
262+
fixed_length_chunking=RagFileChunkingConfig.FixedLengthChunking(
263+
chunk_size=TEST_CHUNK_SIZE,
264+
chunk_overlap=TEST_CHUNK_OVERLAP,
265+
),
266+
),
267+
)
259268
# GCS
260-
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig()
269+
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig(
270+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
271+
)
261272
TEST_IMPORT_FILES_CONFIG_GCS.gcs_source.uris = [TEST_GCS_PATH]
262-
TEST_IMPORT_FILES_CONFIG_GCS.rag_file_parsing_config.use_advanced_pdf_parsing = False
273+
TEST_IMPORT_FILES_CONFIG_GCS.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = (
274+
False
275+
)
263276
TEST_IMPORT_REQUEST_GCS = ImportRagFilesRequest(
264277
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
265278
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_GCS,
@@ -272,24 +285,28 @@
272285
TEST_DRIVE_FOLDER_2 = (
273286
f"https://drive.google.com/drive/folders/{TEST_DRIVE_FOLDER_ID}?resourcekey=0-eiOT3"
274287
)
275-
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig()
288+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig(
289+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
290+
)
276291
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.google_drive_source.resource_ids = [
277292
GoogleDriveSource.ResourceId(
278293
resource_id=TEST_DRIVE_FOLDER_ID,
279294
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
280295
)
281296
]
282-
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.rag_file_parsing_config.use_advanced_pdf_parsing = (
297+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = (
283298
False
284299
)
285-
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig()
300+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig(
301+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
302+
)
286303
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.google_drive_source.resource_ids = [
287304
GoogleDriveSource.ResourceId(
288305
resource_id=TEST_DRIVE_FOLDER_ID,
289306
resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
290307
)
291308
]
292-
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.rag_file_parsing_config.use_advanced_pdf_parsing = (
309+
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = (
293310
True
294311
)
295312
TEST_IMPORT_REQUEST_DRIVE_FOLDER = ImportRagFilesRequest(
@@ -304,11 +321,12 @@
304321
TEST_DRIVE_FILE_ID = "456"
305322
TEST_DRIVE_FILE = f"https://drive.google.com/file/d/{TEST_DRIVE_FILE_ID}"
306323
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE = ImportRagFilesConfig(
307-
rag_file_chunking_config=RagFileChunkingConfig(
308-
chunk_size=TEST_CHUNK_SIZE,
309-
chunk_overlap=TEST_CHUNK_OVERLAP,
324+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
325+
rag_file_parsing_config=RagFileParsingConfig(
326+
advanced_parser=RagFileParsingConfig.AdvancedParser(
327+
use_advanced_pdf_parsing=False
328+
)
310329
),
311-
rag_file_parsing_config=RagFileParsingConfig(use_advanced_pdf_parsing=False),
312330
)
313331
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.max_embedding_requests_per_min = 800
314332

@@ -362,11 +380,12 @@
362380
),
363381
],
364382
)
383+
TEST_RAG_FILE_PARSING_CONFIG = RagFileParsingConfig(
384+
advanced_parser=RagFileParsingConfig.AdvancedParser(use_advanced_pdf_parsing=False)
385+
)
365386
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE = ImportRagFilesConfig(
366-
rag_file_chunking_config=RagFileChunkingConfig(
367-
chunk_size=TEST_CHUNK_SIZE,
368-
chunk_overlap=TEST_CHUNK_OVERLAP,
369-
)
387+
rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG,
388+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
370389
)
371390
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE.slack_source.channels = [
372391
GapicSlackSource.SlackChannels(
@@ -418,10 +437,8 @@
418437
],
419438
)
420439
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE = ImportRagFilesConfig(
421-
rag_file_chunking_config=RagFileChunkingConfig(
422-
chunk_size=TEST_CHUNK_SIZE,
423-
chunk_overlap=TEST_CHUNK_OVERLAP,
424-
)
440+
rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG,
441+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
425442
)
426443
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE.jira_source.jira_queries = [
427444
GapicJiraSource.JiraQueries(
@@ -453,10 +470,8 @@
453470
],
454471
)
455472
TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE = ImportRagFilesConfig(
456-
rag_file_chunking_config=RagFileChunkingConfig(
457-
chunk_size=TEST_CHUNK_SIZE,
458-
chunk_overlap=TEST_CHUNK_OVERLAP,
459-
),
473+
rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG,
474+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
460475
share_point_sources=GapicSharePointSources(
461476
share_point_sources=[
462477
GapicSharePointSources.SharePointSource(
@@ -531,10 +546,7 @@
531546
)
532547

533548
TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE_NO_FOLDERS = ImportRagFilesConfig(
534-
rag_file_chunking_config=RagFileChunkingConfig(
535-
chunk_size=TEST_CHUNK_SIZE,
536-
chunk_overlap=TEST_CHUNK_OVERLAP,
537-
),
549+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
538550
share_point_sources=GapicSharePointSources(
539551
share_point_sources=[
540552
GapicSharePointSources.SharePointSource(
@@ -606,3 +618,13 @@
606618
llm_ranker=LlmRanker(model_name="test-llm-ranker"),
607619
),
608620
)
621+
TEST_RAG_RETRIEVAL_CONFIG_RANK_SERVICE = RagRetrievalConfig(
622+
top_k=2,
623+
filter=Filter(vector_distance_threshold=0.5),
624+
ranking=Ranking(rank_service=RankService(model_name="test-model-name")),
625+
)
626+
TEST_RAG_RETRIEVAL_CONFIG_LLM_RANKER = RagRetrievalConfig(
627+
top_k=2,
628+
filter=Filter(vector_distance_threshold=0.5),
629+
ranking=Ranking(llm_ranker=LlmRanker(model_name="test-model-name")),
630+
)

tests/unit/vertex_rag/test_rag_data_preview.py

+43-24
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
prepare_import_files_request,
2222
set_embedding_model_config,
2323
)
24+
from vertexai.rag.utils.resources import (
25+
ChunkingConfig,
26+
TransformationConfig,
27+
)
2428
from google.cloud.aiplatform_v1beta1 import (
2529
VertexRagDataServiceAsyncClient,
2630
VertexRagDataServiceClient,
@@ -327,6 +331,18 @@ def list_rag_files_pager_mock():
327331
yield list_rag_files_pager_mock
328332

329333

334+
def create_transformation_config(
335+
chunk_size: int = test_rag_constants_preview.TEST_CHUNK_SIZE,
336+
chunk_overlap: int = test_rag_constants_preview.TEST_CHUNK_OVERLAP,
337+
):
338+
return TransformationConfig(
339+
chunking_config=ChunkingConfig(
340+
chunk_size=chunk_size,
341+
chunk_overlap=chunk_overlap,
342+
),
343+
)
344+
345+
330346
def rag_corpus_eq(returned_corpus, expected_corpus):
331347
assert returned_corpus.name == expected_corpus.name
332348
assert returned_corpus.display_name == expected_corpus.display_name
@@ -363,6 +379,10 @@ def import_files_request_eq(returned_request, expected_request):
363379
returned_request.import_rag_files_config.rag_file_parsing_config
364380
== expected_request.import_rag_files_config.rag_file_parsing_config
365381
)
382+
assert (
383+
returned_request.import_rag_files_config.rag_file_transformation_config
384+
== expected_request.import_rag_files_config.rag_file_transformation_config
385+
)
366386

367387

368388
@pytest.mark.usefixtures("google_auth_mock")
@@ -795,6 +815,17 @@ def test_delete_file_failure(self):
795815
e.match("Failed in RagFile deletion due to")
796816

797817
def test_prepare_import_files_request_list_gcs_uris(self):
818+
paths = [test_rag_constants_preview.TEST_GCS_PATH]
819+
request = prepare_import_files_request(
820+
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
821+
paths=paths,
822+
transformation_config=create_transformation_config(),
823+
)
824+
import_files_request_eq(
825+
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_GCS
826+
)
827+
828+
def test_prepare_import_files_request_list_gcs_uris_no_transformation_config(self):
798829
paths = [test_rag_constants_preview.TEST_GCS_PATH]
799830
request = prepare_import_files_request(
800831
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
@@ -817,8 +848,7 @@ def test_prepare_import_files_request_drive_folders(self, path):
817848
request = prepare_import_files_request(
818849
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
819850
paths=[path],
820-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
821-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
851+
transformation_config=create_transformation_config(),
822852
)
823853
import_files_request_eq(
824854
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_DRIVE_FOLDER
@@ -835,8 +865,7 @@ def test_prepare_import_files_request_drive_folders_with_pdf_parsing(self, path)
835865
request = prepare_import_files_request(
836866
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
837867
paths=[path],
838-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
839-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
868+
transformation_config=create_transformation_config(),
840869
use_advanced_pdf_parsing=True,
841870
)
842871
import_files_request_eq(
@@ -848,8 +877,7 @@ def test_prepare_import_files_request_drive_files(self):
848877
request = prepare_import_files_request(
849878
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
850879
paths=paths,
851-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
852-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
880+
transformation_config=create_transformation_config(),
853881
max_embedding_requests_per_min=800,
854882
)
855883
import_files_request_eq(
@@ -862,8 +890,7 @@ def test_prepare_import_files_request_invalid_drive_path(self):
862890
prepare_import_files_request(
863891
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
864892
paths=paths,
865-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
866-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
893+
transformation_config=create_transformation_config(),
867894
)
868895
e.match("is not a valid Google Drive url")
869896

@@ -873,17 +900,15 @@ def test_prepare_import_files_request_invalid_path(self):
873900
prepare_import_files_request(
874901
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
875902
paths=paths,
876-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
877-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
903+
transformation_config=create_transformation_config(),
878904
)
879905
e.match("path must be a Google Cloud Storage uri or a Google Drive url")
880906

881907
def test_prepare_import_files_request_slack_source(self):
882908
request = prepare_import_files_request(
883909
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
884910
source=test_rag_constants_preview.TEST_SLACK_SOURCE,
885-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
886-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
911+
transformation_config=create_transformation_config(),
887912
)
888913
import_files_request_eq(
889914
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_SLACK_SOURCE
@@ -893,8 +918,7 @@ def test_prepare_import_files_request_jira_source(self):
893918
request = prepare_import_files_request(
894919
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
895920
source=test_rag_constants_preview.TEST_JIRA_SOURCE,
896-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
897-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
921+
transformation_config=create_transformation_config(),
898922
)
899923
import_files_request_eq(
900924
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_JIRA_SOURCE
@@ -904,8 +928,7 @@ def test_prepare_import_files_request_sharepoint_source(self):
904928
request = prepare_import_files_request(
905929
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
906930
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE,
907-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
908-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
931+
transformation_config=create_transformation_config(),
909932
)
910933
import_files_request_eq(
911934
request, test_rag_constants_preview.TEST_IMPORT_REQUEST_SHARE_POINT_SOURCE
@@ -916,8 +939,7 @@ def test_prepare_import_files_request_sharepoint_source_2_drives(self):
916939
prepare_import_files_request(
917940
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
918941
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_2_DRIVES,
919-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
920-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
942+
transformation_config=create_transformation_config(),
921943
)
922944
e.match("drive_name and drive_id cannot both be set.")
923945

@@ -926,8 +948,7 @@ def test_prepare_import_files_request_sharepoint_source_2_folders(self):
926948
prepare_import_files_request(
927949
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
928950
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_2_FOLDERS,
929-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
930-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
951+
transformation_config=create_transformation_config(),
931952
)
932953
e.match("sharepoint_folder_path and sharepoint_folder_id cannot both be set.")
933954

@@ -936,17 +957,15 @@ def test_prepare_import_files_request_sharepoint_source_no_drives(self):
936957
prepare_import_files_request(
937958
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
938959
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_NO_DRIVES,
939-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
940-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
960+
transformation_config=create_transformation_config(),
941961
)
942962
e.match("Either drive_name and drive_id must be set.")
943963

944964
def test_prepare_import_files_request_sharepoint_source_no_folders(self):
945965
request = prepare_import_files_request(
946966
corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME,
947967
source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_NO_FOLDERS,
948-
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE,
949-
chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP,
968+
transformation_config=create_transformation_config(),
950969
)
951970
import_files_request_eq(
952971
request,

tests/unit/vertex_rag/test_rag_retrieval_preview.py

+22
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,28 @@ def test_retrieval_query_rag_corpora_config_success(self):
141141
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
142142
)
143143

144+
@pytest.mark.usefixtures("retrieve_contexts_mock")
145+
def test_retrieval_query_rag_corpora_config_rank_service_success(self):
146+
response = rag.retrieval_query(
147+
rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID],
148+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
149+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG_RANK_SERVICE,
150+
)
151+
retrieve_contexts_eq(
152+
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
153+
)
154+
155+
@pytest.mark.usefixtures("retrieve_contexts_mock")
156+
def test_retrieval_query_rag_corpora_config_llm_ranker_success(self):
157+
response = rag.retrieval_query(
158+
rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID],
159+
text=test_rag_constants_preview.TEST_QUERY_TEXT,
160+
rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG_LLM_RANKER,
161+
)
162+
retrieve_contexts_eq(
163+
response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE
164+
)
165+
144166
@pytest.mark.usefixtures("rag_client_mock_exception")
145167
def test_retrieval_query_failure(self):
146168
with pytest.raises(RuntimeError) as e:

0 commit comments

Comments
 (0)