
Commit 184cca5

vertex-sdk-bot authored and copybara-github committed

fix: Propagating import result sink correctly in the vertexai sdk.

PiperOrigin-RevId: 741317609

1 parent a0b6919

File tree: 4 files changed, +99 -4 lines changed

tests/unit/vertex_rag/test_rag_constants.py (+2)

@@ -209,6 +209,8 @@
         ),
     ),
 )
+TEST_IMPORT_RESULT_GCS_SINK = "gs://test-bucket/test-object.ndjson"
+TEST_IMPORT_RESULT_BIGQUERY_SINK = "bq://test-project.test_dataset.test_table"
 # GCS
 TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig(
     rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,

tests/unit/vertex_rag/test_rag_data.py (+54)

@@ -276,6 +276,14 @@ def import_files_request_eq(returned_request, expected_request):
         returned_request.import_rag_files_config.rag_file_transformation_config
         == expected_request.import_rag_files_config.rag_file_transformation_config
     )
+    assert (
+        returned_request.import_rag_files_config.import_result_gcs_sink
+        == expected_request.import_rag_files_config.import_result_gcs_sink
+    )
+    assert (
+        returned_request.import_rag_files_config.import_result_bigquery_sink
+        == expected_request.import_rag_files_config.import_result_bigquery_sink
+    )
 
 
 @pytest.mark.usefixtures("google_auth_mock")

@@ -517,6 +525,26 @@ def test_import_files(self, import_files_mock):
 
         assert response.imported_rag_files_count == 2
 
+    def test_import_files_with_import_result_gcs_sink(self, import_files_mock):
+        response = rag.import_files(
+            corpus_name=test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME,
+            paths=[test_rag_constants.TEST_GCS_PATH],
+            import_result_sink=test_rag_constants.TEST_IMPORT_RESULT_GCS_SINK,
+        )
+        import_files_mock.assert_called_once()
+
+        assert response.imported_rag_files_count == 2
+
+    def test_import_files_with_import_result_bigquery_sink(self, import_files_mock):
+        response = rag.import_files(
+            corpus_name=test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME,
+            paths=[test_rag_constants.TEST_GCS_PATH],
+            import_result_sink=test_rag_constants.TEST_IMPORT_RESULT_BIGQUERY_SINK,
+        )
+        import_files_mock.assert_called_once()
+
+        assert response.imported_rag_files_count == 2
+
     @pytest.mark.usefixtures("rag_data_client_mock_exception")
     def test_import_files_failure(self):
         with pytest.raises(RuntimeError) as e:

@@ -536,6 +564,32 @@ async def test_import_files_async(self, import_files_async_mock):
 
         assert response.result().imported_rag_files_count == 2
 
+    @pytest.mark.asyncio
+    async def test_import_files_with_import_result_gcs_sink_async(
+        self, import_files_async_mock
+    ):
+        response = await rag.import_files_async(
+            corpus_name=test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME,
+            paths=[test_rag_constants.TEST_GCS_PATH],
+            import_result_sink=test_rag_constants.TEST_IMPORT_RESULT_GCS_SINK,
+        )
+        import_files_async_mock.assert_called_once()
+
+        assert response.result().imported_rag_files_count == 2
+
+    @pytest.mark.asyncio
+    async def test_import_files_with_import_result_bigquery_sink_async(
+        self, import_files_async_mock
+    ):
+        response = await rag.import_files_async(
+            corpus_name=test_rag_constants.TEST_RAG_CORPUS_RESOURCE_NAME,
+            paths=[test_rag_constants.TEST_GCS_PATH],
+            import_result_sink=test_rag_constants.TEST_IMPORT_RESULT_BIGQUERY_SINK,
+        )
+        import_files_async_mock.assert_called_once()
+
+        assert response.result().imported_rag_files_count == 2
+
     @pytest.mark.asyncio
     @pytest.mark.usefixtures("rag_data_async_client_mock_exception")
     async def test_import_files_async_failure(self):
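The new tests drive the `import_result_sink` keyword through the public `rag.import_files` and `rag.import_files_async` entry points against mocked clients. Outside the fixtures, the same async call pattern would look roughly like the sketch below; this is a minimal illustration, and the project, location, corpus, bucket, and path values are hypothetical placeholders, not values from this commit. Note that `await operation.result()` follows the real `operation_async.AsyncOperation` API, whereas the tests call `.result()` synchronously on a mock.

import asyncio

import vertexai
from vertexai import rag

# Hypothetical project/location/corpus values for illustration only.
vertexai.init(project="my-project", location="us-central1")
CORPUS = "projects/my-project/locations/us-central1/ragCorpora/123"

async def run_import() -> None:
    # Start the import; per-file outcomes are written to the NDJSON
    # object named by import_result_sink.
    operation = await rag.import_files_async(
        corpus_name=CORPUS,
        paths=["gs://my-bucket/docs/"],
        import_result_sink="gs://my-bucket/results/import.ndjson",
    )
    # Wait for the long-running operation and read the response,
    # mirroring the imported_rag_files_count check in the tests.
    response = await operation.result()
    print(response.imported_rag_files_count)

asyncio.run(run_import())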

vertexai/rag/rag_data.py (+26 -4)

@@ -395,6 +395,7 @@ def import_files(
     transformation_config: Optional[TransformationConfig] = None,
     timeout: int = 600,
     max_embedding_requests_per_min: int = 1000,
+    import_result_sink: Optional[str] = None,
     partial_failures_sink: Optional[str] = None,
     parser: Optional[LayoutParserConfig] = None,
 ) -> ImportRagFilesResponse:

@@ -509,8 +510,17 @@ def import_files(
             here. If unspecified, a default value of 1,000
             QPM would be used.
         timeout: Default is 600 seconds.
-        partial_failures_sink: Either a GCS path to store partial failures or a
-            BigQuery table to store partial failures. The format is
+        import_result_sink: Either a GCS path to store import results or a
+            BigQuery table to store import results. The format is
+            "gs://my-bucket/my/object.ndjson" for GCS or
+            "bq://my-project.my-dataset.my-table" for BigQuery. An existing GCS
+            object cannot be used. However, the BigQuery table may or may not
+            exist - if it does not exist, it will be created. If it does exist,
+            the schema will be checked and the import results will be appended
+            to the table.
+        partial_failures_sink: Deprecated. Prefer to use `import_result_sink`.
+            Either a GCS path to store partial failures or a BigQuery table to
+            store partial failures. The format is
             "gs://my-bucket/my/object.ndjson" for GCS or
             "bq://my-project.my-dataset.my-table" for BigQuery. An existing GCS
             object cannot be used. However, the BigQuery table may or may not

@@ -534,6 +544,7 @@ def import_files(
         source=source,
         transformation_config=transformation_config,
         max_embedding_requests_per_min=max_embedding_requests_per_min,
+        import_result_sink=import_result_sink,
         partial_failures_sink=partial_failures_sink,
         parser=parser,
     )

@@ -552,6 +563,7 @@ async def import_files_async(
     source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None,
     transformation_config: Optional[TransformationConfig] = None,
     max_embedding_requests_per_min: int = 1000,
+    import_result_sink: Optional[str] = None,
     partial_failures_sink: Optional[str] = None,
     parser: Optional[LayoutParserConfig] = None,
 ) -> operation_async.AsyncOperation:

@@ -666,8 +678,17 @@ async def import_files_async(
             page on the project to set an appropriate value
             here. If unspecified, a default value of 1,000
             QPM would be used.
-        partial_failures_sink: Either a GCS path to store partial failures or a
-            BigQuery table to store partial failures. The format is
+        import_result_sink: Either a GCS path to store import results or a
+            BigQuery table to store import results. The format is
+            "gs://my-bucket/my/object.ndjson" for GCS or
+            "bq://my-project.my-dataset.my-table" for BigQuery. An existing GCS
+            object cannot be used. However, the BigQuery table may or may not
+            exist - if it does not exist, it will be created. If it does exist,
+            the schema will be checked and the import results will be appended
+            to the table.
+        partial_failures_sink: Deprecated. Prefer to use `import_result_sink`.
+            Either a GCS path to store partial failures or a BigQuery table to
+            store partial failures. The format is
             "gs://my-bucket/my/object.ndjson" for GCS or
             "bq://my-project.my-dataset.my-table" for BigQuery. An existing GCS
             object cannot be used. However, the BigQuery table may or may not

@@ -691,6 +712,7 @@ async def import_files_async(
         source=source,
         transformation_config=transformation_config,
         max_embedding_requests_per_min=max_embedding_requests_per_min,
+        import_result_sink=import_result_sink,
         partial_failures_sink=partial_failures_sink,
         parser=parser,
     )
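As the updated docstrings state, `import_result_sink` accepts either a GCS NDJSON object that must not already exist, or a BigQuery table that is created if missing and appended to otherwise; `partial_failures_sink` remains accepted but deprecated. A minimal synchronous sketch of both formats follows, with hypothetical corpus, bucket, dataset, and table names:

from vertexai import rag

# Hypothetical corpus resource name for illustration only.
CORPUS = "projects/my-project/locations/us-central1/ragCorpora/123"

# GCS sink: the target object must not already exist.
response = rag.import_files(
    corpus_name=CORPUS,
    paths=["gs://my-bucket/docs/"],
    import_result_sink="gs://my-bucket/results/import.ndjson",
)

# BigQuery sink: the table is created if absent, appended to if present.
response = rag.import_files(
    corpus_name=CORPUS,
    paths=["gs://my-bucket/docs/"],
    import_result_sink="bq://my-project.my_dataset.import_results",
)
print(response.imported_rag_files_count)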

vertexai/rag/utils/_gapic_utils.py (+17)

@@ -360,6 +360,7 @@ def prepare_import_files_request(
     source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None,
     transformation_config: Optional[TransformationConfig] = None,
     max_embedding_requests_per_min: int = 1000,
+    import_result_sink: Optional[str] = None,
     partial_failures_sink: Optional[str] = None,
     parser: Optional[LayoutParserConfig] = None,
 ) -> ImportRagFilesRequest:

@@ -407,6 +408,22 @@
         max_embedding_requests_per_min=max_embedding_requests_per_min,
     )
 
+    import_result_sink = import_result_sink or partial_failures_sink
+
+    if import_result_sink is not None:
+        if import_result_sink.startswith("gs://"):
+            import_rag_files_config.partial_failure_gcs_sink.output_uri_prefix = (
+                import_result_sink
+            )
+        elif import_result_sink.startswith("bq://"):
+            import_rag_files_config.partial_failure_bigquery_sink.output_uri = (
+                import_result_sink
+            )
+        else:
+            raise ValueError(
+                "import_result_sink must be a GCS path or a BigQuery table."
+            )
+
     if source is not None:
         gapic_source = convert_source_for_rag_import(source)
         if isinstance(gapic_source, GapicSlackSource):
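The propagation fix itself is the dispatch above: the new keyword falls back to the deprecated `partial_failures_sink`, and the URI scheme then decides which sink field of the request config is populated. The same logic in isolation, as a simplified sketch in which plain dataclasses stand in for the generated `ImportRagFilesConfig` proto messages:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class GcsSink:
    output_uri_prefix: str = ""

@dataclass
class BigQuerySink:
    output_uri: str = ""

@dataclass
class ImportConfig:
    # Stand-ins for the generated ImportRagFilesConfig sink messages.
    partial_failure_gcs_sink: GcsSink = field(default_factory=GcsSink)
    partial_failure_bigquery_sink: BigQuerySink = field(default_factory=BigQuerySink)

def route_result_sink(
    config: ImportConfig,
    import_result_sink: Optional[str],
    partial_failures_sink: Optional[str] = None,
) -> None:
    # The new keyword wins; the deprecated one is kept as a fallback.
    sink = import_result_sink or partial_failures_sink
    if sink is None:
        return
    if sink.startswith("gs://"):
        config.partial_failure_gcs_sink.output_uri_prefix = sink
    elif sink.startswith("bq://"):
        config.partial_failure_bigquery_sink.output_uri = sink
    else:
        raise ValueError("import_result_sink must be a GCS path or a BigQuery table.")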
