Skip to content

Commit a323679

Browse files
cleop-googlecopybara-github
authored and committed
feat: support creating multimodal datasets from a JSONL file containing Gemini requests
PiperOrigin-RevId: 742286476
1 parent 50fbdee commit a323679

File tree

2 files changed

+410
-24
lines changed

2 files changed

+410
-24
lines changed

google/cloud/aiplatform/preview/datasets.py

+201-22
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Dict, List, Optional, Tuple
2020

2121
from google.auth import credentials as auth_credentials
22+
from google.cloud import storage
2223
from google.cloud.aiplatform import base
2324
from google.cloud.aiplatform import compat
2425
from google.cloud.aiplatform import initializer
@@ -47,6 +48,7 @@
4748
_GEMINI_TEMPLATE_CONFIG_SOURCE_FIELD = "geminiTemplateConfigSource"
4849
_GEMINI_TEMPLATE_CONFIG_FIELD = "geminiTemplateConfig"
4950
_PROMPT_URI_FIELD = "promptUri"
51+
_REQUEST_COLUMN_NAME_FIELD = "requestColumnName"
5052

5153
_LOGGER = base.Logger(__name__)
5254

@@ -56,6 +58,7 @@ def _try_import_bigframes():
5658
try:
5759
import bigframes
5860
import bigframes.pandas
61+
import bigframes.bigquery
5962

6063
return bigframes
6164
except ImportError as exc:
@@ -69,9 +72,19 @@ def _get_metadata_for_bq(
6972
bq_uri: str,
7073
template_config: Optional[gca_dataset_service.GeminiTemplateConfig] = None,
7174
prompt_uri: Optional[str] = None,
75+
request_column_name: Optional[str] = None,
7276
) -> struct_pb2.Value:
73-
if template_config and prompt_uri:
74-
raise ValueError("Only one of template_config and prompt_uri can be specified.")
77+
if (
78+
sum(
79+
1
80+
for param in (template_config, prompt_uri, request_column_name)
81+
if param is not None
82+
)
83+
> 1
84+
):
85+
raise ValueError(
86+
"Only one of template_config, prompt_uri, request_column_name can be specified."
87+
)
7588

7689
input_config = {_INPUT_CONFIG_FIELD: {_BIGQUERY_SOURCE_FIELD: {_URI_FIELD: bq_uri}}}
7790
if template_config is not None:
@@ -85,6 +98,10 @@ def _get_metadata_for_bq(
8598
input_config[_GEMINI_TEMPLATE_CONFIG_SOURCE_FIELD] = {
8699
_PROMPT_URI_FIELD: prompt_uri
87100
}
101+
if request_column_name is not None:
102+
input_config[_GEMINI_TEMPLATE_CONFIG_SOURCE_FIELD] = {
103+
_REQUEST_COLUMN_NAME_FIELD: request_column_name
104+
}
88105
return json_format.ParseDict(input_config, struct_pb2.Value())
89106

90107

@@ -462,6 +479,7 @@ class MultimodalDataset(base.VertexAiResourceNounWithFutureManager):
462479
_delete_method = "delete_dataset"
463480
_parse_resource_name_method = "parse_dataset_path"
464481
_format_resource_name_method = "dataset_path"
482+
_DEFAULT_REQUEST_COLUMN_NAME = "requests"
465483

466484
def __init__(
467485
self,
@@ -577,6 +595,7 @@ def from_bigquery(
577595
"""
578596
return cls._create_from_bigquery(
579597
bigquery_uri=bigquery_uri,
598+
metadata=_get_metadata_for_bq(bq_uri=bigquery_uri),
580599
display_name=display_name,
581600
project=project,
582601
location=location,
@@ -663,8 +682,10 @@ def from_pandas(
663682
destination_table=target_table_id,
664683
if_exists="replace",
665684
)
685+
bigquery_uri = f"bq://{target_table_id}"
666686
return cls._create_from_bigquery(
667-
bigquery_uri=f"bq://{target_table_id}",
687+
bigquery_uri=bigquery_uri,
688+
metadata=_get_metadata_for_bq(bq_uri=bigquery_uri),
668689
display_name=display_name,
669690
project=project,
670691
location=location,
@@ -748,8 +769,129 @@ def from_bigframes(
748769
destination_table=target_table_id,
749770
if_exists="replace",
750771
)
772+
bigquery_uri = f"bq://{target_table_id}"
773+
return cls._create_from_bigquery(
774+
bigquery_uri=bigquery_uri,
775+
metadata=_get_metadata_for_bq(bq_uri=bigquery_uri),
776+
display_name=display_name,
777+
project=project,
778+
location=location,
779+
credentials=credentials,
780+
labels=labels,
781+
sync=sync,
782+
create_request_timeout=create_request_timeout,
783+
)
784+
785+
@classmethod
786+
def from_gemini_request_jsonl(
787+
cls,
788+
*,
789+
gcs_uri: str,
790+
target_table_id: str,
791+
display_name: Optional[str] = None,
792+
project: Optional[str] = None,
793+
location: Optional[str] = None,
794+
credentials: Optional[auth_credentials.Credentials] = None,
795+
labels: Optional[Dict[str, str]] = None,
796+
sync: bool = True,
797+
create_request_timeout: Optional[float] = None,
798+
) -> "MultimodalDataset":
799+
"""Creates a multimodal dataset from a JSONL file stored on GCS.
800+
801+
The JSONL file should contain instances of Gemini
802+
`GenerateContentRequest` on each line. The data will be stored in a
803+
BigQuery table with a single column called "requests". The
804+
request_column_name in the dataset metadata will be set to "requests".
805+
806+
Args:
807+
gcs_uri (str):
808+
The Google Cloud Storage URI of the JSONL file to import.
809+
For example, 'gs://my-bucket/path/to/data.jsonl'
810+
target_table_id (str):
811+
The BigQuery table id where the dataframe will be uploaded. The
812+
table id can be in the format of "dataset.table" or
813+
"project.dataset.table". If a table already exists with the
814+
given table id, it will be overwritten. Note that the BigQuery
815+
dataset must already exist.
816+
display_name (str):
817+
Optional. The user-defined name of the dataset. The name can be
818+
up to 128 characters long and can consist of any UTF-8
819+
characters.
820+
project (str):
821+
Optional. Project to create this dataset in. Overrides project
822+
set in aiplatform.init.
823+
location (str):
824+
Optional. Location to create this dataset in. Overrides location
825+
set in aiplatform.init.
826+
credentials (auth_credentials.Credentials):
827+
Optional. Custom credentials to use to create this dataset.
828+
Overrides credentials set in aiplatform.init.
829+
labels (Dict[str, str]):
830+
Optional. The labels with user-defined metadata to organize your
831+
datasets. Label keys and values can be no longer than 64
832+
characters (Unicode codepoints), can only contain lowercase
833+
letters, numeric characters, underscores and dashes.
834+
International characters are allowed. See https://goo.gl/xmQnxf
835+
for more information on and examples of labels. No more than 64
836+
user labels can be associated with one dataset (System labels
837+
are excluded). System reserved label keys are prefixed with
838+
"aiplatform.googleapis.com/" and are immutable.
839+
sync (bool):
840+
Optional. Whether to execute this method synchronously. If
841+
False, this method will be executed in concurrent Future and any
842+
downstream object will be immediately returned and synced when
843+
the Future has completed.
844+
create_request_timeout (float):
845+
Optional. The timeout for the dataset creation request.
846+
847+
Returns:
848+
The created multimodal dataset.
849+
"""
850+
bigframes = _try_import_bigframes()
851+
if not project:
852+
project = initializer.global_config.project
853+
# TODO(b/400355374): `table_id` should be optional, and if not provided,
854+
# we generate a random table id. Also, check if we can use a default
855+
# dataset that's created from the SDK.
856+
target_table_id = _normalize_and_validate_table_id(
857+
table_id=target_table_id, project=project
858+
)
859+
860+
gcs_uri_prefix = "gs://"
861+
if gcs_uri.startswith(gcs_uri_prefix):
862+
gcs_uri = gcs_uri[len(gcs_uri_prefix) :]
863+
parts = gcs_uri.split("/", 1)
864+
if len(parts) != 2:
865+
raise ValueError(
866+
"Invalid GCS URI format. Expected: gs://bucket-name/object-path"
867+
)
868+
bucket_name = parts[0]
869+
blob_name = parts[1]
870+
871+
storage_client = storage.Client(project=project)
872+
bucket = storage_client.bucket(bucket_name)
873+
blob = bucket.blob(blob_name)
874+
request_column_name = cls._DEFAULT_REQUEST_COLUMN_NAME
875+
876+
jsonl_string = blob.download_as_text()
877+
lines = [line.strip() for line in jsonl_string.splitlines() if line.strip()]
878+
df = pandas.DataFrame(lines, columns=[request_column_name])
879+
880+
temp_bigframes_df = bigframes.pandas.read_pandas(df)
881+
temp_bigframes_df[request_column_name] = bigframes.bigquery.parse_json(
882+
temp_bigframes_df[request_column_name]
883+
)
884+
temp_bigframes_df.to_gbq(
885+
destination_table=target_table_id,
886+
if_exists="replace",
887+
)
888+
889+
bigquery_uri = f"bq://{target_table_id}"
751890
return cls._create_from_bigquery(
752-
bigquery_uri=f"bq://{target_table_id}",
891+
bigquery_uri=bigquery_uri,
892+
metadata=_get_metadata_for_bq(
893+
bq_uri=bigquery_uri, request_column_name=request_column_name
894+
),
753895
display_name=display_name,
754896
project=project,
755897
location=location,
@@ -765,6 +907,7 @@ def _create_from_bigquery(
765907
cls,
766908
*,
767909
bigquery_uri: str,
910+
metadata: struct_pb2.Value,
768911
display_name: Optional[str] = None,
769912
project: Optional[str] = None,
770913
location: Optional[str] = None,
@@ -788,7 +931,7 @@ def _create_from_bigquery(
788931
dataset = gca_dataset.Dataset(
789932
display_name=display_name,
790933
metadata_schema_uri=_MULTIMODAL_METADATA_SCHEMA_URI,
791-
metadata=_get_metadata_for_bq(bq_uri=bigquery_uri),
934+
metadata=metadata,
792935
labels=labels,
793936
)
794937
parent = initializer.global_config.common_location_path(
@@ -976,6 +1119,27 @@ def template_config(self) -> Optional[GeminiTemplateConfig]:
9761119

9771120
return None
9781121

1122+
@property
def request_column_name(self) -> Optional[str]:
    """Return the request column name if it is set in the dataset metadata.

    The request column name specifies a column in the dataset that contains
    assembled Gemini `GenerateContentRequest` instances.
    """
    self._assert_gca_resource_is_available()
    metadata = self._gca_resource.metadata
    # No template-config source attached to this dataset at all.
    if _GEMINI_TEMPLATE_CONFIG_SOURCE_FIELD not in metadata:
        return None
    template_source = metadata[_GEMINI_TEMPLATE_CONFIG_SOURCE_FIELD]
    # The source exists but carries a template/prompt, not a request column.
    if _REQUEST_COLUMN_NAME_FIELD not in template_source:
        return None
    return template_source[_REQUEST_COLUMN_NAME_FIELD]
1142+
9791143
def assemble(
9801144
self,
9811145
*,
@@ -1003,12 +1167,15 @@ def assemble(
10031167
load_dataframe is True, otherwise None.
10041168
"""
10051169
bigframes = _try_import_bigframes()
1006-
template_config_to_use = _resolve_template_config(self, template_config)
1170+
request = gca_dataset_service.AssembleDataRequest(name=self.resource_name)
1171+
if self.request_column_name is not None:
1172+
request.request_column_name = self.request_column_name
1173+
else:
1174+
template_config_to_use = _resolve_template_config(self, template_config)
1175+
request.gemini_template_config = (
1176+
template_config_to_use._raw_gemini_template_config
1177+
)
10071178

1008-
request = gca_dataset_service.AssembleDataRequest(
1009-
name=self.resource_name,
1010-
gemini_template_config=template_config_to_use._raw_gemini_template_config,
1011-
)
10121179
assemble_lro = self.api_client.assemble_data(
10131180
request=request, timeout=assemble_request_timeout
10141181
)
@@ -1051,14 +1218,13 @@ def assess_tuning_resources(
10511218
dataset.
10521219
10531220
"""
1054-
template_config_to_use = _resolve_template_config(self, template_config)
1055-
request = gca_dataset_service.AssessDataRequest(
1056-
name=self.resource_name,
1057-
tuning_resource_usage_assessment_config=gca_dataset_service.AssessDataRequest.TuningResourceUsageAssessmentConfig(
1221+
request = _build_assess_data_request(self, template_config)
1222+
request.tuning_resource_usage_assessment_config = (
1223+
gca_dataset_service.AssessDataRequest.TuningResourceUsageAssessmentConfig(
10581224
model_name=model_name
1059-
),
1060-
gemini_template_config=template_config_to_use._raw_gemini_template_config,
1225+
)
10611226
)
1227+
10621228
assessment_result = (
10631229
self.api_client.assess_data(request=request, timeout=assess_request_timeout)
10641230
.result(timeout=None)
@@ -1116,14 +1282,12 @@ def assess_tuning_validity(
11161282
if dataset_usage_enum == DatasetUsage.DATASET_USAGE_UNSPECIFIED:
11171283
raise ValueError("Dataset usage must be specified.")
11181284

1119-
template_config_to_use = _resolve_template_config(self, template_config)
1120-
request = gca_dataset_service.AssessDataRequest(
1121-
name=self.resource_name,
1122-
tuning_validation_assessment_config=gca_dataset_service.AssessDataRequest.TuningValidationAssessmentConfig(
1285+
request = _build_assess_data_request(self, template_config)
1286+
request.tuning_validation_assessment_config = (
1287+
gca_dataset_service.AssessDataRequest.TuningValidationAssessmentConfig(
11231288
model_name=model_name,
11241289
dataset_usage=dataset_usage_enum,
1125-
),
1126-
gemini_template_config=template_config_to_use._raw_gemini_template_config,
1290+
)
11271291
)
11281292
assess_lro = self.api_client.assess_data(
11291293
request=request, timeout=assess_request_timeout
@@ -1147,3 +1311,18 @@ def _resolve_template_config(
11471311
return dataset.template_config
11481312
else:
11491313
raise ValueError("No template config was passed or attached to the dataset.")
1314+
1315+
1316+
def _build_assess_data_request(
    dataset: MultimodalDataset,
    template_config: Optional[GeminiTemplateConfig] = None,
):
    """Build an ``AssessDataRequest`` for the given multimodal dataset.

    If the dataset's metadata carries a request column name, the request is
    configured to read pre-assembled Gemini requests from that column.
    Otherwise a template config is resolved (from the argument or the
    dataset's attached template) and set on the request.
    """
    request = gca_dataset_service.AssessDataRequest(name=dataset.resource_name)
    column_name = dataset.request_column_name
    if column_name is None:
        resolved_config = _resolve_template_config(dataset, template_config)
        request.gemini_template_config = (
            resolved_config._raw_gemini_template_config
        )
    else:
        request.request_column_name = column_name
    return request

0 commit comments

Comments
 (0)