Skip to content

Commit 98459aa

Browse files
fthoele and copybara-github authored and committed
feat: Add validation of the BigQuery location when creating a MultimodalDataset
PiperOrigin-RevId: 741515869
1 parent 184cca5 commit 98459aa

File tree

2 files changed

+133
-26
lines changed

2 files changed

+133
-26
lines changed

google/cloud/aiplatform/preview/datasets.py

+55-15
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,46 @@ def _get_metadata_for_bq(
8888
return json_format.ParseDict(input_config, struct_pb2.Value())
8989

9090

91-
def _normalize_table_id(*, table_id: str, project: str):
92-
if table_id.count(".") == 1:
93-
# table_id has the "dataset.table" format, prepend the project
94-
return f"{project}.{table_id}"
95-
elif table_id.count(".") != 2:
96-
raise ValueError(f"invalid table id: {table_id}")
97-
return table_id
91+
def _normalize_and_validate_table_id(
    *,
    table_id: str,
    project: Optional[str] = None,
    vertex_location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
):
    """Normalizes a BigQuery table id and validates its project and location.

    Args:
        table_id: Table id in "table", "dataset.table", or
            "project.dataset.table" form; missing parts are filled in from
            ``project`` by ``TableReference.from_string``.
        project: Project the multimodal dataset belongs to. Falls back to the
            globally initialized project.
        vertex_location: Location of the multimodal dataset. Falls back to the
            globally initialized location.
        credentials: Credentials for the BigQuery metadata lookup. Falls back
            to the globally initialized credentials.

    Returns:
        The fully qualified "project.dataset.table" id.

    Raises:
        ValueError: If the table's project differs from ``project``, or if the
            containing BigQuery dataset's location differs from
            ``vertex_location``.
    """
    # Deferred import so google-cloud-bigquery stays an optional dependency.
    from google.cloud import bigquery  # pylint: disable=g-import-not-at-top

    if not project:
        project = initializer.global_config.project
    if not vertex_location:
        vertex_location = initializer.global_config.location
    if not credentials:
        credentials = initializer.global_config.credentials

    table_ref = bigquery.TableReference.from_string(table_id, default_project=project)
    if table_ref.project != project:
        raise ValueError(
            f"The BigQuery table "
            f"`{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}`"
            " must be in the same project as the multimodal dataset."
            f" The multimodal dataset is in `{project}`, but the BigQuery table"
            f" is in `{table_ref.project}`."
        )

    # Network call: fetch the dataset's metadata to learn its location.
    dataset_ref = bigquery.DatasetReference(
        project=table_ref.project, dataset_id=table_ref.dataset_id
    )
    client = bigquery.Client(project=project, credentials=credentials)
    bq_dataset = client.get_dataset(dataset_ref=dataset_ref)
    # NOTE(review): this comparison is case-sensitive; BigQuery can report
    # multi-region locations such as "US" while Vertex locations are lowercase
    # regions — confirm case folding is never needed here.
    if bq_dataset.location != vertex_location:
        raise ValueError(
            f"The BigQuery dataset"
            f" `{dataset_ref.project}.{dataset_ref.dataset_id}` must be in the"
            " same location as the multimodal dataset. The multimodal dataset"
            f" is in `{vertex_location}`, but the BigQuery dataset is in"
            f" `{bq_dataset.location}`."
        )
    return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}"
98131

99132

100133
class GeminiExample:
@@ -577,7 +610,8 @@ def from_pandas(
577610
table id can be in the format of "dataset.table" or
578611
"project.dataset.table". If a table already exists with the
579612
given table id, it will be overwritten. Note that the BigQuery
580-
dataset must already exist.
613+
dataset must already exist and be in the same location as the
614+
multimodal dataset.
581615
display_name (str):
582616
Optional. The user-defined name of the dataset. The name can be
583617
up to 128 characters long and can consist of any UTF-8
@@ -614,12 +648,15 @@ def from_pandas(
614648
The created multimodal dataset.
615649
"""
616650
bigframes = _try_import_bigframes()
617-
if not project:
618-
project = initializer.global_config.project
619651
# TODO(b/400355374): `table_id` should be optional, and if not provided,
620652
# we generate a random table id. Also, check if we can use a default
621653
# dataset that's created from the SDK.
622-
target_table_id = _normalize_table_id(table_id=target_table_id, project=project)
654+
target_table_id = _normalize_and_validate_table_id(
655+
table_id=target_table_id,
656+
project=project,
657+
vertex_location=location,
658+
credentials=credentials,
659+
)
623660

624661
temp_bigframes_df = bigframes.pandas.read_pandas(dataframe)
625662
temp_bigframes_df.to_gbq(
@@ -662,7 +699,8 @@ def from_bigframes(
662699
table id can be in the format of "dataset.table" or
663700
"project.dataset.table". If a table already exists with the
664701
given table id, it will be overwritten. Note that the BigQuery
665-
dataset must already exist.
702+
dataset must already exist and be in the same location as the
703+
multimodal dataset.
666704
display_name (str):
667705
Optional. The user-defined name of the dataset. The name can be
668706
up to 128 characters long and can consist of any UTF-8
@@ -697,12 +735,14 @@ def from_bigframes(
697735
Returns:
698736
The created multimodal dataset.
699737
"""
700-
project_id = project or initializer.global_config.project
701738
# TODO(b/400355374): `table_id` should be optional, and if not provided,
702739
# we generate a random table id. Also, check if we can use a default
703740
# dataset that's created from the SDK.
704-
target_table_id = _normalize_table_id(
705-
table_id=target_table_id, project=project_id
741+
target_table_id = _normalize_and_validate_table_id(
742+
table_id=target_table_id,
743+
project=project,
744+
vertex_location=location,
745+
credentials=credentials,
706746
)
707747
dataframe.to_gbq(
708748
destination_table=target_table_id,

tests/unit/aiplatform/test_multimodal_datasets.py

+78-11
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from google import auth
2222
from google.api_core import operation
2323
from google.auth import credentials as auth_credentials
24+
from google.cloud import bigquery
2425
from google.cloud import aiplatform
2526
from google.cloud.aiplatform import base
2627
from google.cloud.aiplatform import initializer
@@ -42,6 +43,7 @@
4243

4344
_TEST_PROJECT = "test-project"
4445
_TEST_LOCATION = "us-central1"
46+
_TEST_ALTERNATE_LOCATION = "europe-west6"
4547
_TEST_ID = "1028944691210842416"
4648
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
4749
_TEST_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/datasets/{_TEST_ID}"
@@ -53,6 +55,8 @@
5355
)
5456

5557
_TEST_SOURCE_URI_BQ = "bq://my-project.my-dataset.table"
58+
_TEST_TARGET_BQ_DATASET = f"{_TEST_PROJECT}.target-dataset"
59+
_TEST_TARGET_BQ_TABLE = f"{_TEST_TARGET_BQ_DATASET}.target-table"
5660
_TEST_DISPLAY_NAME = "my_dataset_1234"
5761
_TEST_METADATA_SCHEMA_URI_MULTIMODAL = (
5862
"gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
@@ -168,6 +172,24 @@ def bigframes_import_mock():
168172
del sys.modules["bigframes.pandas"]
169173

170174

175+
@pytest.fixture
def get_bq_dataset_mock():
    """Patches bigquery.Client.get_dataset to return a dataset located in _TEST_LOCATION."""
    with mock.patch.object(bigquery.Client, "get_dataset") as get_bq_dataset_mock:
        # Only `location` is read by the validation code under test.
        bq_dataset = mock.Mock()
        bq_dataset.location = _TEST_LOCATION
        get_bq_dataset_mock.return_value = bq_dataset
        yield get_bq_dataset_mock
182+
183+
184+
@pytest.fixture
def get_bq_dataset_alternate_location_mock():
    """Patches bigquery.Client.get_dataset to return a dataset in a non-matching location."""
    with mock.patch.object(bigquery.Client, "get_dataset") as get_bq_dataset_mock:
        # `location` deliberately differs from the Vertex location so the
        # location-mismatch ValueError path is exercised.
        bq_dataset = mock.Mock()
        bq_dataset.location = _TEST_ALTERNATE_LOCATION
        get_bq_dataset_mock.return_value = bq_dataset
        yield get_bq_dataset_mock
191+
192+
171193
@pytest.fixture
172194
def update_dataset_with_template_config_mock():
173195
with mock.patch.object(
@@ -259,7 +281,7 @@ def test_create_dataset_from_bigquery(self, create_dataset_mock, sync):
259281
)
260282

261283
@pytest.mark.skip(reason="flaky with other tests mocking bigframes")
262-
@pytest.mark.usefixtures("get_dataset_mock")
284+
@pytest.mark.usefixtures("get_dataset_mock", "get_bq_dataset_mock")
263285
def test_create_dataset_from_pandas(
264286
self, create_dataset_mock, bigframes_import_mock
265287
):
@@ -273,55 +295,100 @@ def test_create_dataset_from_pandas(
273295
"answer": ["answer"],
274296
}
275297
)
276-
bq_table = "my-project.my-dataset.my-table"
277298
ummd.MultimodalDataset.from_pandas(
278299
dataframe=dataframe,
279-
target_table_id=bq_table,
300+
target_table_id=_TEST_TARGET_BQ_TABLE,
280301
display_name=_TEST_DISPLAY_NAME,
281302
)
282303
expected_dataset = gca_dataset.Dataset(
283304
display_name=_TEST_DISPLAY_NAME,
284305
metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_MULTIMODAL,
285-
metadata={"inputConfig": {"bigquerySource": {"uri": f"bq://{bq_table}"}}},
306+
metadata={
307+
"inputConfig": {
308+
"bigquerySource": {"uri": f"bq://{_TEST_TARGET_BQ_TABLE}"}
309+
}
310+
},
286311
)
287312
create_dataset_mock.assert_called_once_with(
288313
dataset=expected_dataset,
289314
parent=_TEST_PARENT,
290315
timeout=None,
291316
)
292317
bigframes_mock.to_gbq.assert_called_once_with(
293-
destination_table=bq_table,
318+
destination_table=_TEST_TARGET_BQ_TABLE,
294319
if_exists="replace",
295320
)
296321

297322
@pytest.mark.skip(reason="flaky with other tests mocking bigframes")
@pytest.mark.usefixtures(
    "bigframes_import_mock", "get_dataset_mock", "get_bq_dataset_mock"
)
def test_create_dataset_from_bigframes(self, create_dataset_mock):
    """Happy path: a bigframes dataframe is written to BigQuery and a dataset created."""
    aiplatform.init(project=_TEST_PROJECT)
    bigframes_df = mock.Mock()
    ummd.MultimodalDataset.from_bigframes(
        dataframe=bigframes_df,
        target_table_id=_TEST_TARGET_BQ_TABLE,
        display_name=_TEST_DISPLAY_NAME,
    )

    # The dataframe must be materialized into the target table, overwriting
    # any existing table with the same id.
    bigframes_df.to_gbq.assert_called_once_with(
        destination_table=_TEST_TARGET_BQ_TABLE,
        if_exists="replace",
    )
    expected_dataset = gca_dataset.Dataset(
        display_name=_TEST_DISPLAY_NAME,
        metadata_schema_uri=_TEST_METADATA_SCHEMA_URI_MULTIMODAL,
        metadata={
            "inputConfig": {
                "bigquerySource": {"uri": f"bq://{_TEST_TARGET_BQ_TABLE}"}
            }
        },
    )
    create_dataset_mock.assert_called_once_with(
        dataset=expected_dataset,
        parent=_TEST_PARENT,
        timeout=None,
    )
324353

354+
@pytest.mark.skip(reason="flaky with other tests mocking bigframes")
@pytest.mark.usefixtures("bigframes_import_mock")
def test_create_dataset_from_bigframes_different_project_throws_error(self):
    """A target table in a different project than the dataset is rejected."""
    aiplatform.init(project=_TEST_PROJECT)
    bigframes_df = mock.Mock()
    with pytest.raises(ValueError):
        ummd.MultimodalDataset.from_bigframes(
            dataframe=bigframes_df,
            # Project part does not match _TEST_PROJECT initialized above.
            target_table_id="another_project.dataset.table",
            display_name=_TEST_DISPLAY_NAME,
        )
365+
366+
@pytest.mark.skip(reason="flaky with other tests mocking bigframes")
@pytest.mark.usefixtures(
    "bigframes_import_mock", "get_bq_dataset_alternate_location_mock"
)
def test_create_dataset_from_bigframes_different_location_throws_error(self):
    """A BigQuery dataset whose location differs from the Vertex location is rejected."""
    aiplatform.init(project=_TEST_PROJECT)
    bigframes_df = mock.Mock()
    # The fixture makes get_dataset report _TEST_ALTERNATE_LOCATION, which
    # mismatches the default Vertex location.
    with pytest.raises(ValueError):
        ummd.MultimodalDataset.from_bigframes(
            dataframe=bigframes_df,
            target_table_id=_TEST_TARGET_BQ_TABLE,
            display_name=_TEST_DISPLAY_NAME,
        )
379+
380+
@pytest.mark.skip(reason="flaky with other tests mocking bigframes")
@pytest.mark.usefixtures("bigframes_import_mock")
def test_create_dataset_from_bigframes_invalid_target_table_id_throws_error(self):
    """A table id that cannot be parsed as a BigQuery table reference is rejected."""
    aiplatform.init(project=_TEST_PROJECT)
    bigframes_df = mock.Mock()
    with pytest.raises(ValueError):
        ummd.MultimodalDataset.from_bigframes(
            dataframe=bigframes_df,
            # Not a valid "dataset.table" / "project.dataset.table" id.
            target_table_id="invalid-table",
            display_name=_TEST_DISPLAY_NAME,
        )
391+
325392
@pytest.mark.usefixtures("get_dataset_mock")
326393
def test_update_dataset(self, update_dataset_mock):
327394
aiplatform.init(project=_TEST_PROJECT)

0 commit comments

Comments (0)