@@ -17,6 +17,7 @@

import dataclasses
from typing import Dict, List, Optional, Tuple
+import uuid

from google.auth import credentials as auth_credentials
from google.cloud import storage
@@ -41,7 +42,8 @@
_MULTIMODAL_METADATA_SCHEMA_URI = (
    "gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
)
-
+_DEFAULT_BQ_DATASET_PREFIX = "vertex_datasets"
+_DEFAULT_BQ_TABLE_PREFIX = "multimodal_dataset"
_INPUT_CONFIG_FIELD = "inputConfig"
_BIGQUERY_SOURCE_FIELD = "bigquerySource"
_URI_FIELD = "uri"
@@ -147,6 +149,37 @@ def _normalize_and_validate_table_id(
    return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}"


+def _create_default_bigquery_dataset_if_not_exists(
+    *,
+    project: Optional[str] = None,
+    location: Optional[str] = None,
+    credentials: Optional[auth_credentials.Credentials] = None,
+) -> str:
+    # Loading bigquery lazily to avoid auto-loading it when importing vertexai
+    from google.cloud import bigquery  # pylint: disable=g-import-not-at-top
+
+    if not project:
+        project = initializer.global_config.project
+    if not location:
+        location = initializer.global_config.location
+    if not credentials:
+        credentials = initializer.global_config.credentials
+
+    bigquery_client = bigquery.Client(project=project, credentials=credentials)
+    location_str = location.lower().replace("-", "_")
+    dataset_id = bigquery.DatasetReference(
+        project, f"{_DEFAULT_BQ_DATASET_PREFIX}_{location_str}"
+    )
+    dataset = bigquery.Dataset(dataset_ref=dataset_id)
+    dataset.location = location
+    bigquery_client.create_dataset(dataset, exists_ok=True)
+    return f"{dataset_id.project}.{dataset_id.dataset_id}"
+
+
+def _generate_target_table_id(dataset_id: str):
+    return f"{dataset_id}.{_DEFAULT_BQ_TABLE_PREFIX}_{str(uuid.uuid4())}"
+
+
class GeminiExample:
    """A class representing a Gemini example."""

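For context on the two helpers added above: a minimal sketch of the identifiers they produce, assuming a hypothetical project `my-project` in `us-central1`. Note the full `uuid4` string is used verbatim, so real generated table names are longer than the shortened `..._4cbf7ffd` example quoted in the docstrings below.

```python
import uuid

# Mirrors the helper logic above outside the SDK, for illustration only.
_DEFAULT_BQ_DATASET_PREFIX = "vertex_datasets"
_DEFAULT_BQ_TABLE_PREFIX = "multimodal_dataset"

project = "my-project"    # hypothetical project id
location = "us-central1"  # hyphens are normalized to underscores

dataset_id = f"{project}.{_DEFAULT_BQ_DATASET_PREFIX}_{location.lower().replace('-', '_')}"
table_id = f"{dataset_id}.{_DEFAULT_BQ_TABLE_PREFIX}_{uuid.uuid4()}"
print(table_id)
# my-project.vertex_datasets_us_central1.multimodal_dataset_<random uuid>
```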
@@ -610,7 +643,7 @@ def from_pandas(
        cls,
        *,
        dataframe: pandas.DataFrame,
-        target_table_id: str,
+        target_table_id: Optional[str] = None,
        display_name: Optional[str] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
@@ -625,12 +658,14 @@ def from_pandas(
            dataframe (pandas.DataFrame):
                The pandas dataframe to be used for the created dataset.
            target_table_id (str):
-                The BigQuery table id where the dataframe will be uploaded. The
-                table id can be in the format of "dataset.table" or
-                "project.dataset.table". If a table already exists with the
+                Optional. The BigQuery table id where the dataframe will be
+                uploaded. The table id can be in the format of "dataset.table"
+                or "project.dataset.table". If a table already exists with the
                given table id, it will be overwritten. Note that the BigQuery
                dataset must already exist and be in the same location as the
-                multimodal dataset.
+                multimodal dataset. If not provided, a generated table id will
+                be created in the `vertex_datasets` dataset (e.g.
+                `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
            display_name (str):
                Optional. The user-defined name of the dataset. The name can be
                up to 128 characters long and can consist of any UTF-8
@@ -667,21 +702,43 @@ def from_pandas(
            The created multimodal dataset.
        """
        bigframes = _try_import_bigframes()
-        # TODO(b/400355374): `table_id` should be optional, and if not provided,
-        # we generate a random table id. Also, check if we can use a default
-        # dataset that's created from the SDK.
-        target_table_id = _normalize_and_validate_table_id(
-            table_id=target_table_id,
-            project=project,
-            vertex_location=location,
+        from google.cloud import bigquery  # pylint: disable=g-import-not-at-top
+
+        if not project:
+            project = initializer.global_config.project
+        if not location:
+            location = initializer.global_config.location
+        if not credentials:
+            credentials = initializer.global_config.credentials
+
+        if target_table_id:
+            target_table_id = _normalize_and_validate_table_id(
+                table_id=target_table_id,
+                project=project,
+                vertex_location=location,
+                credentials=credentials,
+            )
+        else:
+            dataset_id = _create_default_bigquery_dataset_if_not_exists(
+                project=project, location=location, credentials=credentials
+            )
+            target_table_id = _generate_target_table_id(dataset_id)
+
+        session_options = bigframes.BigQueryOptions(
            credentials=credentials,
+            project=project,
+            location=location,
        )
-
-        temp_bigframes_df = bigframes.pandas.read_pandas(dataframe)
-        temp_bigframes_df.to_gbq(
-            destination_table=target_table_id,
-            if_exists="replace",
+        with bigframes.connect(session_options) as session:
+            temp_bigframes_df = session.read_pandas(dataframe)
+            temp_table_id = temp_bigframes_df.to_gbq()
+        client = bigquery.Client(project=project, credentials=credentials)
+        copy_job = client.copy_table(
+            sources=temp_table_id,
+            destination=target_table_id,
        )
+        copy_job.result()
+
        bigquery_uri = f"bq://{target_table_id}"
        return cls._create_from_bigquery(
            bigquery_uri=bigquery_uri,
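With this change, `target_table_id` can be omitted entirely. A hedged usage sketch, assuming the `vertexai.preview.datasets` import path and `MultimodalDataset` class name for this preview surface (adjust to your SDK version):

```python
import pandas as pd
import vertexai
from vertexai.preview import datasets  # assumed import path

vertexai.init(project="my-project", location="us-central1")  # hypothetical project

df = pd.DataFrame({"question": ["What is in this image?"]})

# No target_table_id: the SDK creates the default dataset if needed
# (e.g. my-project.vertex_datasets_us_central1) and generates a fresh
# table id inside it.
dataset = datasets.MultimodalDataset.from_pandas(dataframe=df)
```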
@@ -700,7 +757,7 @@ def from_bigframes(
        cls,
        *,
        dataframe: "bigframes.pandas.DataFrame",  # type: ignore # noqa: F821
-        target_table_id: str,
+        target_table_id: Optional[str] = None,
        display_name: Optional[str] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
@@ -716,12 +773,14 @@ def from_bigframes(
                The BigFrames dataframe that will be used for the created
                dataset.
            target_table_id (str):
-                The BigQuery table id where the dataframe will be uploaded. The
-                table id can be in the format of "dataset.table" or
-                "project.dataset.table". If a table already exists with the
+                Optional. The BigQuery table id where the dataframe will be
+                uploaded. The table id can be in the format of "dataset.table"
+                or "project.dataset.table". If a table already exists with the
                given table id, it will be overwritten. Note that the BigQuery
                dataset must already exist and be in the same location as the
-                multimodal dataset.
+                multimodal dataset. If not provided, a generated table id will
+                be created in the `vertex_datasets` dataset (e.g.
+                `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
            display_name (str):
                Optional. The user-defined name of the dataset. The name can be
                up to 128 characters long and can consist of any UTF-8
@@ -756,19 +815,32 @@ def from_bigframes(
        Returns:
            The created multimodal dataset.
        """
-        # TODO(b/400355374): `table_id` should be optional, and if not provided,
-        # we generate a random table id. Also, check if we can use a default
-        # dataset that's created from the SDK.
-        target_table_id = _normalize_and_validate_table_id(
-            table_id=target_table_id,
-            project=project,
-            vertex_location=location,
-            credentials=credentials,
-        )
-        dataframe.to_gbq(
-            destination_table=target_table_id,
-            if_exists="replace",
+        from google.cloud import bigquery  # pylint: disable=g-import-not-at-top
+
+        if target_table_id:
+            target_table_id = _normalize_and_validate_table_id(
+                table_id=target_table_id,
+                project=project,
+                vertex_location=location,
+                credentials=credentials,
+            )
+        else:
+            dataset_id = _create_default_bigquery_dataset_if_not_exists(
+                project=project, location=location, credentials=credentials
+            )
+            target_table_id = _generate_target_table_id(dataset_id)
+
+        if not project:
+            project = initializer.global_config.project
+
+        temp_table_id = dataframe.to_gbq()
+        client = bigquery.Client(project=project, credentials=credentials)
+        copy_job = client.copy_table(
+            sources=temp_table_id,
+            destination=target_table_id,
        )
+        copy_job.result()
+
        bigquery_uri = f"bq://{target_table_id}"
        return cls._create_from_bigquery(
            bigquery_uri=bigquery_uri,
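Both `from_pandas` and `from_bigframes` now share the same staging pattern: `to_gbq()` with no destination writes the dataframe to a temporary table, and a plain BigQuery copy job then materializes it at the final table id. A standalone sketch of that copy step with hypothetical table ids:

```python
from google.cloud import bigquery

client = bigquery.Client(project="my-project")  # hypothetical project

# Copy the staged temporary table onto the final destination table;
# result() blocks until the copy job completes.
copy_job = client.copy_table(
    sources="my-project.some_temp_dataset.staged_table",  # hypothetical temp table
    destination="my-project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd",
)
copy_job.result()
```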
@@ -787,7 +859,7 @@ def from_gemini_request_jsonl(
        cls,
        *,
        gcs_uri: str,
-        target_table_id: str,
+        target_table_id: Optional[str] = None,
        display_name: Optional[str] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
@@ -808,11 +880,14 @@ def from_gemini_request_jsonl(
                The Google Cloud Storage URI of the JSONL file to import.
                For example, 'gs://my-bucket/path/to/data.jsonl'
            target_table_id (str):
-                The BigQuery table id where the dataframe will be uploaded. The
-                table id can be in the format of "dataset.table" or
-                "project.dataset.table". If a table already exists with the
+                Optional. The BigQuery table id where the dataframe will be
+                uploaded. The table id can be in the format of "dataset.table"
+                or "project.dataset.table". If a table already exists with the
                given table id, it will be overwritten. Note that the BigQuery
-                dataset must already exist.
+                dataset must already exist and be in the same location as the
+                multimodal dataset. If not provided, a generated table id will
+                be created in the `vertex_datasets` dataset (e.g.
+                `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`).
            display_name (str):
                Optional. The user-defined name of the dataset. The name can be
                up to 128 characters long and can consist of any UTF-8
@@ -848,14 +923,23 @@ def from_gemini_request_jsonl(
            The created multimodal dataset.
        """
        bigframes = _try_import_bigframes()
+        from google.cloud import bigquery  # pylint: disable=g-import-not-at-top
+
        if not project:
            project = initializer.global_config.project
-        # TODO(b/400355374): `table_id` should be optional, and if not provided,
-        # we generate a random table id. Also, check if we can use a default
-        # dataset that's created from the SDK.
-        target_table_id = _normalize_and_validate_table_id(
-            table_id=target_table_id, project=project
-        )
+
+        if target_table_id:
+            target_table_id = _normalize_and_validate_table_id(
+                table_id=target_table_id,
+                project=project,
+                vertex_location=location,
+                credentials=credentials,
+            )
+        else:
+            dataset_id = _create_default_bigquery_dataset_if_not_exists(
+                project=project, location=location, credentials=credentials
+            )
+            target_table_id = _generate_target_table_id(dataset_id)

        gcs_uri_prefix = "gs://"
        if gcs_uri.startswith(gcs_uri_prefix):
@@ -877,13 +961,21 @@ def from_gemini_request_jsonl(
        lines = [line.strip() for line in jsonl_string.splitlines() if line.strip()]
        df = pandas.DataFrame(lines, columns=[request_column_name])

-        temp_bigframes_df = bigframes.pandas.read_pandas(df)
-        temp_bigframes_df[request_column_name] = bigframes.bigquery.parse_json(
-            temp_bigframes_df[request_column_name]
+        session_options = bigframes.BigQueryOptions(
+            credentials=credentials,
+            project=project,
+            location=location,
        )
-        temp_bigframes_df.to_gbq(
-            destination_table=target_table_id,
-            if_exists="replace",
+        with bigframes.connect(session_options) as session:
+            temp_bigframes_df = session.read_pandas(df)
+            temp_bigframes_df[request_column_name] = bigframes.bigquery.parse_json(
+                temp_bigframes_df[request_column_name]
+            )
+            temp_table_id = temp_bigframes_df.to_gbq()
+        client = bigquery.Client(project=project, credentials=credentials)
+        client.copy_table(
+            sources=temp_table_id,
+            destination=target_table_id,
        )

        bigquery_uri = f"bq://{target_table_id}"
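For reference, each line of the imported JSONL is a complete Gemini request, which ends up parsed into a single JSON column of the target table. A hedged usage sketch of the default-table path, reusing the docstring's example URI (import path and request payload are assumptions):

```python
from vertexai.preview import datasets  # assumed import path

# One Gemini request per line in the JSONL file, e.g.:
# {"contents": [{"role": "user", "parts": [{"text": "What is in this image?"}]}]}
dataset = datasets.MultimodalDataset.from_gemini_request_jsonl(
    gcs_uri="gs://my-bucket/path/to/data.jsonl",  # from the docstring example
)
```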