
Commit f5043a6

fthoele authored and copybara-github committed
feat: Add the possibility to create multimodal datasets without explicitly specifying a bigquery dataset/table.
PiperOrigin-RevId: 744018606
1 parent 0b520cd commit f5043a6
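A minimal usage sketch of the new behavior, assuming the dataset class in this file is exposed as MultimodalDataset (the class name is not shown in this diff; project, location, and dataframe contents are illustrative):

import pandas as pd
import vertexai
from google.cloud.aiplatform.preview import datasets

vertexai.init(project="my-project", location="us-central1")

df = pd.DataFrame({"request": ['{"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}']})

# With this change, target_table_id may be omitted: the SDK creates (if
# needed) a default BigQuery dataset named vertex_datasets_<location> and
# stages the dataframe in a generated multimodal_dataset_<uuid> table.
dataset = datasets.MultimodalDataset.from_pandas(dataframe=df)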

File tree

2 files changed: +209 -119 lines changed


google/cloud/aiplatform/preview/datasets.py: +144 -52
@@ -17,6 +17,7 @@
 
 import dataclasses
 from typing import Dict, List, Optional, Tuple
+import uuid
 
 from google.auth import credentials as auth_credentials
 from google.cloud import storage
@@ -41,7 +42,8 @@
 _MULTIMODAL_METADATA_SCHEMA_URI = (
     "gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
 )
-
+_DEFAULT_BQ_DATASET_PREFIX = "vertex_datasets"
+_DEFAULT_BQ_TABLE_PREFIX = "multimodal_dataset"
 _INPUT_CONFIG_FIELD = "inputConfig"
 _BIGQUERY_SOURCE_FIELD = "bigquerySource"
 _URI_FIELD = "uri"
@@ -147,6 +149,37 @@ def _normalize_and_validate_table_id(
     return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}"
 
 
+def _create_default_bigquery_dataset_if_not_exists(
+    *,
+    project: Optional[str] = None,
+    location: Optional[str] = None,
+    credentials: Optional[auth_credentials.Credentials] = None,
+) -> str:
+    # Loading bigquery lazily to avoid auto-loading it when importing vertexai
+    from google.cloud import bigquery  # pylint: disable=g-import-not-at-top
+
+    if not project:
+        project = initializer.global_config.project
+    if not location:
+        location = initializer.global_config.location
+    if not credentials:
+        credentials = initializer.global_config.credentials
+
+    bigquery_client = bigquery.Client(project=project, credentials=credentials)
+    location_str = location.lower().replace("-", "_")
+    dataset_id = bigquery.DatasetReference(
+        project, f"{_DEFAULT_BQ_DATASET_PREFIX}_{location_str}"
+    )
+    dataset = bigquery.Dataset(dataset_ref=dataset_id)
+    dataset.location = location
+    bigquery_client.create_dataset(dataset, exists_ok=True)
+    return f"{dataset_id.project}.{dataset_id.dataset_id}"
+
+
+def _generate_target_table_id(dataset_id: str):
+    return f"{dataset_id}.{_DEFAULT_BQ_TABLE_PREFIX}_{str(uuid.uuid4())}"
+
+
 class GeminiExample:
     """A class representing a Gemini example."""
 
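The two helpers above define the default staging location. A standalone sketch of the identifier scheme they produce (project name and UUID are illustrative):

import uuid

_DEFAULT_BQ_DATASET_PREFIX = "vertex_datasets"
_DEFAULT_BQ_TABLE_PREFIX = "multimodal_dataset"

project, location = "my-project", "us-central1"

# The region is folded into the dataset name, so each location gets its own
# default dataset.
dataset_id = f"{project}.{_DEFAULT_BQ_DATASET_PREFIX}_{location.lower().replace('-', '_')}"
table_id = f"{dataset_id}.{_DEFAULT_BQ_TABLE_PREFIX}_{uuid.uuid4()}"
# e.g. my-project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd-...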
@@ -610,7 +643,7 @@ def from_pandas(
         cls,
         *,
         dataframe: pandas.DataFrame,
-        target_table_id: str,
+        target_table_id: Optional[str] = None,
         display_name: Optional[str] = None,
         project: Optional[str] = None,
         location: Optional[str] = None,
@@ -625,12 +658,14 @@ def from_pandas(
             dataframe (pandas.DataFrame):
                 The pandas dataframe to be used for the created dataset.
             target_table_id (str):
-                The BigQuery table id where the dataframe will be uploaded. The
-                table id can be in the format of "dataset.table" or
-                "project.dataset.table". If a table already exists with the
+                Optional. The BigQuery table id where the dataframe will be
+                uploaded. The table id can be in the format of "dataset.table"
+                or "project.dataset.table". If a table already exists with the
                 given table id, it will be overwritten. Note that the BigQuery
                 dataset must already exist and be in the same location as the
-                multimodal dataset.
+                multimodal dataset. If not provided, a generated table id will
+                be created in the `vertex_datasets` dataset (e.g.
+                `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
             display_name (str):
                 Optional. The user-defined name of the dataset. The name can be
                 up to 128 characters long and can consist of any UTF-8
@@ -667,21 +702,43 @@ def from_pandas(
             The created multimodal dataset.
         """
         bigframes = _try_import_bigframes()
-        # TODO(b/400355374): `table_id` should be optional, and if not provided,
-        # we generate a random table id. Also, check if we can use a default
-        # dataset that's created from the SDK.
-        target_table_id = _normalize_and_validate_table_id(
-            table_id=target_table_id,
-            project=project,
-            vertex_location=location,
+        from google.cloud import bigquery  # pylint: disable=g-import-not-at-top
+
+        if not project:
+            project = initializer.global_config.project
+        if not location:
+            location = initializer.global_config.location
+        if not credentials:
+            credentials = initializer.global_config.credentials
+
+        if target_table_id:
+            target_table_id = _normalize_and_validate_table_id(
+                table_id=target_table_id,
+                project=project,
+                vertex_location=location,
+                credentials=credentials,
+            )
+        else:
+            dataset_id = _create_default_bigquery_dataset_if_not_exists(
+                project=project, location=location, credentials=credentials
+            )
+            target_table_id = _generate_target_table_id(dataset_id)
+
+        session_options = bigframes.BigQueryOptions(
             credentials=credentials,
+            project=project,
+            location=location,
         )
-
-        temp_bigframes_df = bigframes.pandas.read_pandas(dataframe)
-        temp_bigframes_df.to_gbq(
-            destination_table=target_table_id,
-            if_exists="replace",
+        with bigframes.connect(session_options) as session:
+            temp_bigframes_df = session.read_pandas(dataframe)
+            temp_table_id = temp_bigframes_df.to_gbq()
+        client = bigquery.Client(project=project, credentials=credentials)
+        copy_job = client.copy_table(
+            sources=temp_table_id,
+            destination=target_table_id,
         )
+        copy_job.result()
+
         bigquery_uri = f"bq://{target_table_id}"
         return cls._create_from_bigquery(
             bigquery_uri=bigquery_uri,
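The rewritten from_pandas stages the dataframe in a session-managed temporary table and then copies it into the target table with a BigQuery copy job. A minimal standalone sketch of that staging pattern (project, location, and target table id are illustrative):

import bigframes
import pandas as pd
from google.cloud import bigquery

options = bigframes.BigQueryOptions(project="my-project", location="us-central1")
target = "my-project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd"

with bigframes.connect(options) as session:
    # to_gbq() without a destination writes to a session temporary table and
    # returns its fully qualified table id.
    temp_table_id = session.read_pandas(pd.DataFrame({"request": ["..."]})).to_gbq()

# Copy the temporary table into the durable target table.
client = bigquery.Client(project="my-project")
client.copy_table(sources=temp_table_id, destination=target).result()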
@@ -700,7 +757,7 @@ def from_bigframes(
         cls,
         *,
         dataframe: "bigframes.pandas.DataFrame",  # type: ignore # noqa: F821
-        target_table_id: str,
+        target_table_id: Optional[str] = None,
        display_name: Optional[str] = None,
         project: Optional[str] = None,
         location: Optional[str] = None,
@@ -716,12 +773,14 @@ def from_bigframes(
                 The BigFrames dataframe that will be used for the created
                 dataset.
             target_table_id (str):
-                The BigQuery table id where the dataframe will be uploaded. The
-                table id can be in the format of "dataset.table" or
-                "project.dataset.table". If a table already exists with the
+                Optional. The BigQuery table id where the dataframe will be
+                uploaded. The table id can be in the format of "dataset.table"
+                or "project.dataset.table". If a table already exists with the
                 given table id, it will be overwritten. Note that the BigQuery
                 dataset must already exist and be in the same location as the
-                multimodal dataset.
+                multimodal dataset. If not provided, a generated table id will
+                be created in the `vertex_datasets` dataset (e.g.
+                `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
             display_name (str):
                 Optional. The user-defined name of the dataset. The name can be
                 up to 128 characters long and can consist of any UTF-8
@@ -756,19 +815,32 @@ def from_bigframes(
         Returns:
             The created multimodal dataset.
         """
-        # TODO(b/400355374): `table_id` should be optional, and if not provided,
-        # we generate a random table id. Also, check if we can use a default
-        # dataset that's created from the SDK.
-        target_table_id = _normalize_and_validate_table_id(
-            table_id=target_table_id,
-            project=project,
-            vertex_location=location,
-            credentials=credentials,
-        )
-        dataframe.to_gbq(
-            destination_table=target_table_id,
-            if_exists="replace",
+        from google.cloud import bigquery  # pylint: disable=g-import-not-at-top
+
+        if target_table_id:
+            target_table_id = _normalize_and_validate_table_id(
+                table_id=target_table_id,
+                project=project,
+                vertex_location=location,
+                credentials=credentials,
+            )
+        else:
+            dataset_id = _create_default_bigquery_dataset_if_not_exists(
+                project=project, location=location, credentials=credentials
+            )
+            target_table_id = _generate_target_table_id(dataset_id)
+
+        if not project:
+            project = initializer.global_config.project
+
+        temp_table_id = dataframe.to_gbq()
+        client = bigquery.Client(project=project, credentials=credentials)
+        copy_job = client.copy_table(
+            sources=temp_table_id,
+            destination=target_table_id,
         )
+        copy_job.result()
+
         bigquery_uri = f"bq://{target_table_id}"
         return cls._create_from_bigquery(
             bigquery_uri=bigquery_uri,
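from_bigframes takes the same optional-table path, but reuses the dataframe's own session: dataframe.to_gbq() yields a temporary table that is then copied to the target. A hedged usage sketch (again assuming the class is exposed as MultimodalDataset; the source table is illustrative):

import bigframes.pandas as bpd
from google.cloud.aiplatform.preview import datasets

bdf = bpd.read_gbq("my-project.my_dataset.my_prompts")

# target_table_id is now optional; when omitted, the data is copied into a
# generated table in the default vertex_datasets_<location> dataset.
dataset = datasets.MultimodalDataset.from_bigframes(dataframe=bdf)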
@@ -787,7 +859,7 @@ def from_gemini_request_jsonl(
         cls,
         *,
         gcs_uri: str,
-        target_table_id: str,
+        target_table_id: Optional[str] = None,
         display_name: Optional[str] = None,
         project: Optional[str] = None,
         location: Optional[str] = None,
@@ -808,11 +880,14 @@
                 The Google Cloud Storage URI of the JSONL file to import.
                 For example, 'gs://my-bucket/path/to/data.jsonl'
             target_table_id (str):
-                The BigQuery table id where the dataframe will be uploaded. The
-                table id can be in the format of "dataset.table" or
-                "project.dataset.table". If a table already exists with the
+                Optional. The BigQuery table id where the dataframe will be
+                uploaded. The table id can be in the format of "dataset.table"
+                or "project.dataset.table". If a table already exists with the
                 given table id, it will be overwritten. Note that the BigQuery
-                dataset must already exist.
+                dataset must already exist and be in the same location as the
+                multimodal dataset. If not provided, a generated table id will
+                be created in the `vertex_datasets` dataset (e.g.
+                `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
             display_name (str):
                 Optional. The user-defined name of the dataset. The name can be
                 up to 128 characters long and can consist of any UTF-8
@@ -848,14 +923,23 @@
             The created multimodal dataset.
         """
         bigframes = _try_import_bigframes()
+        from google.cloud import bigquery  # pylint: disable=g-import-not-at-top
+
         if not project:
             project = initializer.global_config.project
-        # TODO(b/400355374): `table_id` should be optional, and if not provided,
-        # we generate a random table id. Also, check if we can use a default
-        # dataset that's created from the SDK.
-        target_table_id = _normalize_and_validate_table_id(
-            table_id=target_table_id, project=project
-        )
+
+        if target_table_id:
+            target_table_id = _normalize_and_validate_table_id(
+                table_id=target_table_id,
+                project=project,
+                vertex_location=location,
+                credentials=credentials,
+            )
+        else:
+            dataset_id = _create_default_bigquery_dataset_if_not_exists(
+                project=project, location=location, credentials=credentials
+            )
+            target_table_id = _generate_target_table_id(dataset_id)
 
         gcs_uri_prefix = "gs://"
         if gcs_uri.startswith(gcs_uri_prefix):
@@ -877,13 +961,21 @@
         lines = [line.strip() for line in jsonl_string.splitlines() if line.strip()]
         df = pandas.DataFrame(lines, columns=[request_column_name])
 
-        temp_bigframes_df = bigframes.pandas.read_pandas(df)
-        temp_bigframes_df[request_column_name] = bigframes.bigquery.parse_json(
-            temp_bigframes_df[request_column_name]
+        session_options = bigframes.BigQueryOptions(
+            credentials=credentials,
+            project=project,
+            location=location,
         )
-        temp_bigframes_df.to_gbq(
-            destination_table=target_table_id,
-            if_exists="replace",
+        with bigframes.connect(session_options) as session:
+            temp_bigframes_df = session.read_pandas(df)
+            temp_bigframes_df[request_column_name] = bigframes.bigquery.parse_json(
+                temp_bigframes_df[request_column_name]
+            )
+            temp_table_id = temp_bigframes_df.to_gbq()
+        client = bigquery.Client(project=project, credentials=credentials)
+        client.copy_table(
+            sources=temp_table_id,
+            destination=target_table_id,
         )
 
         bigquery_uri = f"bq://{target_table_id}"
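Before staging, each JSONL line is converted from a raw string into a BigQuery JSON value via bigframes.bigquery.parse_json. A standalone sketch of that step (the column name "request" is illustrative; the code above refers to it only as request_column_name):

import bigframes
import bigframes.bigquery
import pandas as pd

options = bigframes.BigQueryOptions(project="my-project", location="us-central1")
lines = ['{"contents": [{"role": "user", "parts": [{"text": "hi"}]}]}']
df = pd.DataFrame(lines, columns=["request"])

with bigframes.connect(options) as session:
    bdf = session.read_pandas(df)
    # parse_json turns the raw JSONL strings into a BigQuery JSON column, so
    # the staged table stores structured requests rather than plain text.
    bdf["request"] = bigframes.bigquery.parse_json(bdf["request"])
    temp_table_id = bdf.to_gbq()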
