Skip to content

Commit 2d7bc32

Browse files
cleop-google authored and copybara-github committed
feat: allow table targets in multi-region datasets when creating multimodal datasets
PiperOrigin-RevId: 744717895
1 parent 9f21b73 commit 2d7bc32

File tree

2 files changed

+62
-3
lines changed

2 files changed

+62
-3
lines changed

google/cloud/aiplatform/preview/datasets.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
_GEMINI_TEMPLATE_CONFIG_FIELD = "geminiTemplateConfig"
5252
_PROMPT_URI_FIELD = "promptUri"
5353
_REQUEST_COLUMN_NAME_FIELD = "requestColumnName"
54+
_BQ_MULTIREGIONS = {"us", "eu"}
5455

5556
_LOGGER = base.Logger(__name__)
5657

@@ -107,6 +108,16 @@ def _get_metadata_for_bq(
107108
return json_format.ParseDict(input_config, struct_pb2.Value())
108109

109110

111+
def _bq_dataset_location_allowed(
112+
vertex_location: str, bq_dataset_location: str
113+
) -> bool:
114+
if bq_dataset_location == vertex_location:
115+
return True
116+
if bq_dataset_location in _BQ_MULTIREGIONS:
117+
return vertex_location.startswith(bq_dataset_location)
118+
return False
119+
120+
110121
def _normalize_and_validate_table_id(
111122
*,
112123
table_id: str,
@@ -138,7 +149,7 @@ def _normalize_and_validate_table_id(
138149
)
139150
client = bigquery.Client(project=project, credentials=credentials)
140151
bq_dataset = client.get_dataset(dataset_ref=dataset_ref)
141-
if bq_dataset.location != vertex_location:
152+
if not _bq_dataset_location_allowed(vertex_location, bq_dataset.location):
142153
raise ValueError(
143154
f"The BigQuery dataset"
144155
f" `{dataset_ref.project}.{dataset_ref.dataset_id}` must be in the"

tests/unit/aiplatform/test_multimodal_datasets.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,9 @@ def bigframes_import_mock():
198198
bigframes_module.bigquery = bbq_module
199199
sys.modules["bigframes"] = bigframes_module
200200

201+
bigframes_module.BigQueryOptions = mock.MagicMock()
202+
bigframes_module.connect = mock.MagicMock()
203+
201204
yield bigframes_module, bpd_module, bbq_module
202205

203206
del sys.modules["bigframes"]
@@ -461,8 +464,6 @@ def test_create_dataset_from_gemini_request_jsonl(
461464
bpd_module.read_pandas = mock.MagicMock()
462465
bbq_module.parse_json = lambda x: x
463466

464-
bf_module.BigQueryOptions = mock.MagicMock()
465-
bf_module.connect = mock.MagicMock()
466467
session_mock = mock.MagicMock()
467468
bf_module.connect.return_value.__enter__.return_value = session_mock
468469

@@ -744,6 +745,53 @@ def test_assemble_request_column_name(self, assemble_data_mock):
744745
)
745746
assert result_table_id == _TEST_ASSEMBLE_DATA_BIGQUERY_DESTINATION[5:]
746747

748+
@pytest.mark.usefixtures("get_dataset_mock")
def test_create_dataset_from_pandas_multiregion_target_table_allowed(
    self, create_dataset_mock, bigframes_import_mock, bq_client_mock
):
    """A target table in the "us" multi-region is accepted for "us-central1"."""
    # The mocked BigQuery client reports the dataset as living in "us".
    mocked_bq_dataset = bq_client_mock.return_value.get_dataset.return_value
    mocked_bq_dataset.location = "us"

    # Only the bigframes.pandas stand-in is needed from the fixture.
    bpd_module = bigframes_import_mock[1]
    bpd_module.read_pandas = lambda x: mock.Mock()

    aiplatform.init(project=_TEST_PROJECT)
    source_frame = pandas.DataFrame(
        {"question": ["question"], "answer": ["answer"]}
    )

    ummd.MultimodalDataset.from_pandas(
        dataframe=source_frame,
        target_table_id=_TEST_TARGET_BQ_TABLE,
        display_name=_TEST_DISPLAY_NAME,
        location="us-central1",
    )

    # Creation must have gone through exactly once.
    create_dataset_mock.assert_called_once()
771+
772+
def test_create_dataset_from_pandas_multiregion_target_table_location_mismatch_throws_error(
    self, bigframes_import_mock, bq_client_mock
):
    """A target table in the "eu" multi-region is rejected for "us-central1"."""
    # The mocked BigQuery client reports the dataset as living in "eu".
    mocked_bq_dataset = bq_client_mock.return_value.get_dataset.return_value
    mocked_bq_dataset.location = "eu"

    # Only the bigframes.pandas stand-in is needed from the fixture.
    bpd_module = bigframes_import_mock[1]
    bpd_module.read_pandas = lambda x: mock.Mock()

    aiplatform.init(project=_TEST_PROJECT)
    source_frame = pandas.DataFrame(
        {"question": ["question"], "answer": ["answer"]}
    )
    creation_kwargs = dict(
        dataframe=source_frame,
        target_table_id=_TEST_TARGET_BQ_TABLE,
        display_name=_TEST_DISPLAY_NAME,
        location="us-central1",
    )

    with pytest.raises(ValueError):
        ummd.MultimodalDataset.from_pandas(**creation_kwargs)
794+
747795

748796
class TestGeminiExample:
749797
"""Tests for the GeminiExample class."""

0 commit comments

Comments
 (0)