@@ -19,6 +19,7 @@
 from typing import Dict, List, Optional, Tuple

 from google.auth import credentials as auth_credentials
+from google.cloud import storage
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import compat
 from google.cloud.aiplatform import initializer
@@ -47,6 +48,7 @@
 _GEMINI_TEMPLATE_CONFIG_SOURCE_FIELD = "geminiTemplateConfigSource"
 _GEMINI_TEMPLATE_CONFIG_FIELD = "geminiTemplateConfig"
 _PROMPT_URI_FIELD = "promptUri"
+_REQUEST_COLUMN_NAME_FIELD = "requestColumnName"

 _LOGGER = base.Logger(__name__)

@@ -56,6 +58,7 @@ def _try_import_bigframes():
     try:
         import bigframes
         import bigframes.pandas
+        import bigframes.bigquery

        return bigframes
    except ImportError as exc:
@@ -69,9 +72,19 @@ def _get_metadata_for_bq(
     bq_uri: str,
     template_config: Optional[gca_dataset_service.GeminiTemplateConfig] = None,
     prompt_uri: Optional[str] = None,
+    request_column_name: Optional[str] = None,
 ) -> struct_pb2.Value:
-    if template_config and prompt_uri:
-        raise ValueError("Only one of template_config and prompt_uri can be specified.")
+    if (
+        sum(
+            1
+            for param in (template_config, prompt_uri, request_column_name)
+            if param is not None
+        )
+        > 1
+    ):
+        raise ValueError(
+            "Only one of template_config, prompt_uri, request_column_name can be specified."
+        )

     input_config = {_INPUT_CONFIG_FIELD: {_BIGQUERY_SOURCE_FIELD: {_URI_FIELD: bq_uri}}}
     if template_config is not None:
@@ -85,6 +98,10 @@ def _get_metadata_for_bq(
         input_config[_GEMINI_TEMPLATE_CONFIG_SOURCE_FIELD] = {
             _PROMPT_URI_FIELD: prompt_uri
         }
+    if request_column_name is not None:
+        input_config[_GEMINI_TEMPLATE_CONFIG_SOURCE_FIELD] = {
+            _REQUEST_COLUMN_NAME_FIELD: request_column_name
+        }
     return json_format.ParseDict(input_config, struct_pb2.Value())

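Note: with the two hunks above, `_get_metadata_for_bq` now accepts at most one of `template_config`, `prompt_uri`, or `request_column_name` and records whichever source is given in the dataset metadata. A minimal sketch of the request-column path (the table URI is a placeholder, and the camelCase keys are inferred from the constant names near the top of the file, so treat the exact output as an assumption):

    metadata = _get_metadata_for_bq(
        bq_uri="bq://my-project.my_dataset.my_table",  # placeholder URI
        request_column_name="requests",
    )
    # json_format.MessageToDict(metadata) should yield roughly:
    # {
    #     "inputConfig": {"bigquerySource": {"uri": "bq://my-project.my_dataset.my_table"}},
    #     "geminiTemplateConfigSource": {"requestColumnName": "requests"},
    # }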
@@ -462,6 +479,7 @@ class MultimodalDataset(base.VertexAiResourceNounWithFutureManager):
     _delete_method = "delete_dataset"
     _parse_resource_name_method = "parse_dataset_path"
     _format_resource_name_method = "dataset_path"
+    _DEFAULT_REQUEST_COLUMN_NAME = "requests"

     def __init__(
         self,
@@ -577,6 +595,7 @@ def from_bigquery(
         """
         return cls._create_from_bigquery(
             bigquery_uri=bigquery_uri,
+            metadata=_get_metadata_for_bq(bq_uri=bigquery_uri),
             display_name=display_name,
             project=project,
             location=location,
@@ -663,8 +682,10 @@ def from_pandas(
             destination_table=target_table_id,
             if_exists="replace",
         )
+        bigquery_uri = f"bq://{target_table_id}"
         return cls._create_from_bigquery(
-            bigquery_uri=f"bq://{target_table_id}",
+            bigquery_uri=bigquery_uri,
+            metadata=_get_metadata_for_bq(bq_uri=bigquery_uri),
             display_name=display_name,
             project=project,
             location=location,
@@ -748,8 +769,129 @@ def from_bigframes(
             destination_table=target_table_id,
             if_exists="replace",
         )
+        bigquery_uri = f"bq://{target_table_id}"
+        return cls._create_from_bigquery(
+            bigquery_uri=bigquery_uri,
+            metadata=_get_metadata_for_bq(bq_uri=bigquery_uri),
+            display_name=display_name,
+            project=project,
+            location=location,
+            credentials=credentials,
+            labels=labels,
+            sync=sync,
+            create_request_timeout=create_request_timeout,
+        )
+
+    @classmethod
+    def from_gemini_request_jsonl(
+        cls,
+        *,
+        gcs_uri: str,
+        target_table_id: str,
+        display_name: Optional[str] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+        labels: Optional[Dict[str, str]] = None,
+        sync: bool = True,
+        create_request_timeout: Optional[float] = None,
+    ) -> "MultimodalDataset":
+        """Creates a multimodal dataset from a JSONL file stored on GCS.
+
+        The JSONL file should contain one Gemini `GenerateContentRequest`
+        instance per line. The data will be stored in a BigQuery table with
+        a single column called "requests", and the request_column_name in
+        the dataset metadata will be set to "requests".
+
+        Args:
+            gcs_uri (str):
+                The Google Cloud Storage URI of the JSONL file to import.
+                For example, 'gs://my-bucket/path/to/data.jsonl'.
+            target_table_id (str):
+                The BigQuery table id where the data will be uploaded. The
+                table id can be in the format of "dataset.table" or
+                "project.dataset.table". If a table already exists with the
+                given table id, it will be overwritten. Note that the BigQuery
+                dataset must already exist.
+            display_name (str):
+                Optional. The user-defined name of the dataset. The name can be
+                up to 128 characters long and can consist of any UTF-8
+                characters.
+            project (str):
+                Optional. Project to create this dataset in. Overrides project
+                set in aiplatform.init.
+            location (str):
+                Optional. Location to create this dataset in. Overrides
+                location set in aiplatform.init.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to create this dataset.
+                Overrides credentials set in aiplatform.init.
+            labels (Dict[str, str]):
+                Optional. The labels with user-defined metadata to organize
+                your datasets. Label keys and values can be no longer than 64
+                characters (Unicode codepoints), can only contain lowercase
+                letters, numeric characters, underscores and dashes.
+                International characters are allowed. See https://goo.gl/xmQnxf
+                for more information on and examples of labels. No more than 64
+                user labels can be associated with one dataset (System labels
+                are excluded). System reserved label keys are prefixed with
+                "aiplatform.googleapis.com/" and are immutable.
+            sync (bool):
+                Optional. Whether to execute this method synchronously. If
+                False, this method will be executed in a concurrent Future and
+                any downstream object will be immediately returned and synced
+                when the Future has completed.
+            create_request_timeout (float):
+                Optional. The timeout for the dataset creation request.
+
+        Returns:
+            The created multimodal dataset.
+        """
+        bigframes = _try_import_bigframes()
+        if not project:
+            project = initializer.global_config.project
+        # TODO(b/400355374): `table_id` should be optional, and if not provided,
+        # we generate a random table id. Also, check if we can use a default
+        # dataset that's created from the SDK.
+        target_table_id = _normalize_and_validate_table_id(
+            table_id=target_table_id, project=project
+        )
+
+        gcs_uri_prefix = "gs://"
+        if gcs_uri.startswith(gcs_uri_prefix):
+            gcs_uri = gcs_uri[len(gcs_uri_prefix) :]
+        parts = gcs_uri.split("/", 1)
+        if len(parts) != 2:
+            raise ValueError(
+                "Invalid GCS URI format. Expected: gs://bucket-name/object-path"
+            )
+        bucket_name = parts[0]
+        blob_name = parts[1]
+
+        storage_client = storage.Client(project=project)
+        bucket = storage_client.bucket(bucket_name)
+        blob = bucket.blob(blob_name)
+        request_column_name = cls._DEFAULT_REQUEST_COLUMN_NAME
+
+        jsonl_string = blob.download_as_text()
+        lines = [line.strip() for line in jsonl_string.splitlines() if line.strip()]
+        df = pandas.DataFrame(lines, columns=[request_column_name])
+
+        temp_bigframes_df = bigframes.pandas.read_pandas(df)
+        temp_bigframes_df[request_column_name] = bigframes.bigquery.parse_json(
+            temp_bigframes_df[request_column_name]
+        )
+        temp_bigframes_df.to_gbq(
+            destination_table=target_table_id,
+            if_exists="replace",
+        )
+
+        bigquery_uri = f"bq://{target_table_id}"
         return cls._create_from_bigquery(
-            bigquery_uri=f"bq://{target_table_id}",
+            bigquery_uri=bigquery_uri,
+            metadata=_get_metadata_for_bq(
+                bq_uri=bigquery_uri, request_column_name=request_column_name
+            ),
             display_name=display_name,
             project=project,
             location=location,
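A usage sketch for the new classmethod; the project, bucket, table id, and the `vertexai.preview.datasets` import path are illustrative assumptions, not taken from this commit:

    from google.cloud import aiplatform
    from vertexai.preview import datasets  # assumed import path

    aiplatform.init(project="my-project", location="us-central1")

    # requests.jsonl holds one GenerateContentRequest per line, e.g.:
    # {"contents": [{"role": "user", "parts": [{"text": "Say hi."}]}]}
    dataset = datasets.MultimodalDataset.from_gemini_request_jsonl(
        gcs_uri="gs://my-bucket/requests.jsonl",
        target_table_id="my_dataset.requests_table",  # BigQuery dataset must exist
        display_name="gemini-requests",
    )
    print(dataset.request_column_name)  # -> "requests"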
@@ -765,6 +907,7 @@ def _create_from_bigquery(
         cls,
         *,
         bigquery_uri: str,
+        metadata: struct_pb2.Value,
         display_name: Optional[str] = None,
         project: Optional[str] = None,
         location: Optional[str] = None,
@@ -788,7 +931,7 @@ def _create_from_bigquery(
         dataset = gca_dataset.Dataset(
             display_name=display_name,
             metadata_schema_uri=_MULTIMODAL_METADATA_SCHEMA_URI,
-            metadata=_get_metadata_for_bq(bq_uri=bigquery_uri),
+            metadata=metadata,
             labels=labels,
         )
         parent = initializer.global_config.common_location_path(
@@ -976,6 +1119,27 @@ def template_config(self) -> Optional[GeminiTemplateConfig]:

         return None

+    @property
+    def request_column_name(self) -> Optional[str]:
+        """Returns the request column name if it is set in the dataset metadata.
+
+        The request column name specifies a column in the dataset that contains
+        assembled Gemini `GenerateContentRequest` instances.
+        """
+
+        self._assert_gca_resource_is_available()
+        # Dataset has no attached template config source.
+        if _GEMINI_TEMPLATE_CONFIG_SOURCE_FIELD not in self._gca_resource.metadata:
+            return None
+        if (
+            _REQUEST_COLUMN_NAME_FIELD
+            not in self._gca_resource.metadata[_GEMINI_TEMPLATE_CONFIG_SOURCE_FIELD]
+        ):
+            return None
+        return self._gca_resource.metadata[_GEMINI_TEMPLATE_CONFIG_SOURCE_FIELD][
+            _REQUEST_COLUMN_NAME_FIELD
+        ]
+
     def assemble(
         self,
         *,
@@ -1003,12 +1167,15 @@ def assemble(
             load_dataframe is True, otherwise None.
         """
         bigframes = _try_import_bigframes()
-        template_config_to_use = _resolve_template_config(self, template_config)
+        request = gca_dataset_service.AssembleDataRequest(name=self.resource_name)
+        if self.request_column_name is not None:
+            request.request_column_name = self.request_column_name
+        else:
+            template_config_to_use = _resolve_template_config(self, template_config)
+            request.gemini_template_config = (
+                template_config_to_use._raw_gemini_template_config
+            )

-        request = gca_dataset_service.AssembleDataRequest(
-            name=self.resource_name,
-            gemini_template_config=template_config_to_use._raw_gemini_template_config,
-        )
         assemble_lro = self.api_client.assemble_data(
             request=request, timeout=assemble_request_timeout
         )
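`assemble()` now routes on the dataset's metadata: a request-column dataset needs no template config, while templated datasets resolve one as before. A sketch, assuming `assemble()` keeps returning the assembled table plus an optional dataframe as its docstring describes:

    # Dataset created via from_gemini_request_jsonl: the stored
    # request_column_name ("requests") drives assembly.
    table_id, dataframe = dataset.assemble()

    # Templated dataset: an explicit or attached GeminiTemplateConfig is
    # resolved via _resolve_template_config, exactly as before.
    table_id, dataframe = templated_dataset.assemble(template_config=my_template_config)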
@@ -1051,14 +1218,13 @@ def assess_tuning_resources(
             dataset.

         """
-        template_config_to_use = _resolve_template_config(self, template_config)
-        request = gca_dataset_service.AssessDataRequest(
-            name=self.resource_name,
-            tuning_resource_usage_assessment_config=gca_dataset_service.AssessDataRequest.TuningResourceUsageAssessmentConfig(
+        request = _build_assess_data_request(self, template_config)
+        request.tuning_resource_usage_assessment_config = (
+            gca_dataset_service.AssessDataRequest.TuningResourceUsageAssessmentConfig(
                 model_name=model_name
-            ),
-            gemini_template_config=template_config_to_use._raw_gemini_template_config,
+            )
         )
+
         assessment_result = (
             self.api_client.assess_data(request=request, timeout=assess_request_timeout)
             .result(timeout=None)
@@ -1116,14 +1282,12 @@ def assess_tuning_validity(
         if dataset_usage_enum == DatasetUsage.DATASET_USAGE_UNSPECIFIED:
             raise ValueError("Dataset usage must be specified.")

-        template_config_to_use = _resolve_template_config(self, template_config)
-        request = gca_dataset_service.AssessDataRequest(
-            name=self.resource_name,
-            tuning_validation_assessment_config=gca_dataset_service.AssessDataRequest.TuningValidationAssessmentConfig(
+        request = _build_assess_data_request(self, template_config)
+        request.tuning_validation_assessment_config = (
+            gca_dataset_service.AssessDataRequest.TuningValidationAssessmentConfig(
                 model_name=model_name,
                 dataset_usage=dataset_usage_enum,
-            ),
-            gemini_template_config=template_config_to_use._raw_gemini_template_config,
+            )
         )
         assess_lro = self.api_client.assess_data(
             request=request, timeout=assess_request_timeout
@@ -1147,3 +1311,18 @@ def _resolve_template_config(
         return dataset.template_config
     else:
         raise ValueError("No template config was passed or attached to the dataset.")
+
+
+def _build_assess_data_request(
+    dataset: MultimodalDataset,
+    template_config: Optional[GeminiTemplateConfig] = None,
+):
+    request = gca_dataset_service.AssessDataRequest(name=dataset.resource_name)
+    if dataset.request_column_name is not None:
+        request.request_column_name = dataset.request_column_name
+    else:
+        template_config_to_use = _resolve_template_config(dataset, template_config)
+        request.gemini_template_config = (
+            template_config_to_use._raw_gemini_template_config
+        )
+    return request
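The new helper keeps both assess methods consistent with `assemble()`. A brief behavioral sketch with hypothetical dataset objects:

    # Dataset with an attached request column: the column name is copied onto
    # the request and no template config is resolved.
    request = _build_assess_data_request(jsonl_dataset)
    assert request.request_column_name == "requests"

    # Dataset without a request column: a template config must be passed or
    # attached; otherwise _resolve_template_config raises a ValueError.
    request = _build_assess_data_request(templated_dataset, template_config=config)
    # request.gemini_template_config now carries config._raw_gemini_template_config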