
Commit ec3ea30

vertex-sdk-bot authored and copybara-github committed
feat: Add support of newly added fields of ExportData API to SDK
PiperOrigin-RevId: 595178001
1 parent 4d98c55 commit ec3ea30
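
The commit adds an export_data_for_custom_training method (plus supporting helpers) to the managed dataset classes. A minimal usage sketch, assuming an existing text dataset; the project, location, dataset ID, and bucket below are placeholders rather than values taken from this commit:

from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

# Placeholder dataset ID; every managed dataset class inherits the new method
# from the _Dataset base class, which the system test below exercises directly.
dataset = aiplatform.TextDataset("1234567890")

response = dataset.export_data_for_custom_training(
    output_dir="gs://my-bucket/exports",
    annotation_schema_uri=(
        "gs://google-cloud-aiplatform/schema/dataset/annotation/"
        "text_classification_1.0.0.yaml"
    ),
    split={
        "training_fraction": 0.7,
        "validation_fraction": 0.2,
        "test_fraction": 0.1,
    },
)

# The response is the ExportData LRO result rendered as a dict; the key names
# below follow the system test added in this commit.
print(response["exported_files"])
print(response["data_stats"])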

File tree

google/cloud/aiplatform/datasets/dataset.py
tests/system/aiplatform/test_dataset.py

2 files changed: +245 -18 lines changed

google/cloud/aiplatform/datasets/dataset.py (+218 -18)
@@ -15,7 +15,7 @@
 # limitations under the License.
 #

-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 from google.api_core import operation
 from google.auth import credentials as auth_credentials
@@ -27,11 +27,13 @@
 from google.cloud.aiplatform.compat.services import dataset_service_client
 from google.cloud.aiplatform.compat.types import (
     dataset as gca_dataset,
+    dataset_service as gca_dataset_service,
     encryption_spec as gca_encryption_spec,
     io as gca_io,
 )
 from google.cloud.aiplatform.datasets import _datasources
 from google.protobuf import field_mask_pb2
+from google.protobuf import json_format

 _LOGGER = base.Logger(__name__)

@@ -561,6 +563,120 @@ def import_data(
         )
         return self

+    def _validate_and_convert_export_split(
+        self,
+        split: Union[Dict[str, str], Dict[str, float]],
+    ) -> Union[gca_dataset.ExportFilterSplit, gca_dataset.ExportFractionSplit]:
+        """
+        Validates the split for data export. Valid splits are dicts
+        encoding the contents of proto messages ExportFilterSplit or
+        ExportFractionSplit. If the split is valid, this function returns
+        the corresponding converted proto message.
+
+        split (Union[Dict[str, str], Dict[str, float]]):
+            The instructions how the export data should be split between the
+            training, validation and test sets.
+        """
+        if len(split) != 3:
+            raise ValueError(
+                "The provided split for data export does not provide enough "
+                "information. It must have three fields, mapping to training, "
+                "validation and test splits respectively."
+            )
+
+        if not ("training_filter" in split or "training_fraction" in split):
+            raise ValueError(
+                "The provided filter for data export does not provide enough "
+                "information. It must have three fields, mapping to training, "
+                "validation and test respectively."
+            )
+
+        if "training_filter" in split:
+            if (
+                "validation_filter" in split
+                and "test_filter" in split
+                and isinstance(split["training_filter"], str)
+                and isinstance(split["validation_filter"], str)
+                and isinstance(split["test_filter"], str)
+            ):
+                return gca_dataset.ExportFilterSplit(
+                    training_filter=split["training_filter"],
+                    validation_filter=split["validation_filter"],
+                    test_filter=split["test_filter"],
+                )
+            else:
+                raise ValueError(
+                    "The provided ExportFilterSplit does not contain all "
+                    "three required fields: training_filter, "
+                    "validation_filter and test_filter."
+                )
+        else:
+            if (
+                "validation_fraction" in split
+                and "test_fraction" in split
+                and isinstance(split["training_fraction"], float)
+                and isinstance(split["validation_fraction"], float)
+                and isinstance(split["test_fraction"], float)
+            ):
+                return gca_dataset.ExportFractionSplit(
+                    training_fraction=split["training_fraction"],
+                    validation_fraction=split["validation_fraction"],
+                    test_fraction=split["test_fraction"],
+                )
+            else:
+                raise ValueError(
+                    "The provided ExportFractionSplit does not contain all "
+                    "three required fields: training_fraction, "
+                    "validation_fraction and test_fraction."
+                )
+
+    def _get_completed_export_data_operation(
+        self,
+        output_dir: str,
+        export_use: Optional[gca_dataset.ExportDataConfig.ExportUse] = None,
+        annotation_filter: Optional[str] = None,
+        saved_query_id: Optional[str] = None,
+        annotation_schema_uri: Optional[str] = None,
+        split: Optional[
+            Union[gca_dataset.ExportFilterSplit, gca_dataset.ExportFractionSplit]
+        ] = None,
+    ) -> gca_dataset_service.ExportDataResponse:
+        self.wait()
+
+        # TODO(b/171311614): Add support for BigQuery export path
+        export_data_config = gca_dataset.ExportDataConfig(
+            gcs_destination=gca_io.GcsDestination(output_uri_prefix=output_dir)
+        )
+        if export_use is not None:
+            export_data_config.export_use = export_use
+        if annotation_filter is not None:
+            export_data_config.annotation_filter = annotation_filter
+        if saved_query_id is not None:
+            export_data_config.saved_query_id = saved_query_id
+        if annotation_schema_uri is not None:
+            export_data_config.annotation_schema_uri = annotation_schema_uri
+        if split is not None:
+            if isinstance(split, gca_dataset.ExportFilterSplit):
+                export_data_config.filter_split = split
+            elif isinstance(split, gca_dataset.ExportFractionSplit):
+                export_data_config.fraction_split = split
+
+        _LOGGER.log_action_start_against_resource("Exporting", "data", self)
+
+        export_lro = self.api_client.export_data(
+            name=self.resource_name, export_config=export_data_config
+        )
+
+        _LOGGER.log_action_started_against_resource_with_lro(
+            "Export", "data", self.__class__, export_lro
+        )
+
+        export_data_response = export_lro.result()
+
+        _LOGGER.log_action_completed_against_resource("data", "export", self)
+
+        return export_data_response
+
     # TODO(b/174751568) add optional sync support
     def export_data(self, output_dir: str) -> Sequence[str]:
         """Exports data to output dir to GCS.
@@ -585,29 +701,113 @@ def export_data(self, output_dir: str) -> Sequence[str]:
             exported_files (Sequence[str]):
                 All of the files that are exported in this export operation.
         """
-        self.wait()
+        return self._get_completed_export_data_operation(output_dir).exported_files

-        # TODO(b/171311614): Add support for BigQuery export path
-        export_data_config = gca_dataset.ExportDataConfig(
-            gcs_destination=gca_io.GcsDestination(output_uri_prefix=output_dir)
-        )
+    def export_data_for_custom_training(
+        self,
+        output_dir: str,
+        annotation_filter: Optional[str] = None,
+        saved_query_id: Optional[str] = None,
+        annotation_schema_uri: Optional[str] = None,
+        split: Optional[Union[Dict[str, str], Dict[str, float]]] = None,
+    ) -> Dict[str, Any]:
+        """Exports data to output dir to GCS for custom training use case.
+
+        Example annotation_schema_uri (image classification):
+        gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml
+
+        Example split (filter split):
+        {
+            "training_filter": "labels.aiplatform.googleapis.com/ml_use=training",
+            "validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation",
+            "test_filter": "labels.aiplatform.googleapis.com/ml_use=test",
+        }
+        Example split (fraction split):
+        {
+            "training_fraction": 0.7,
+            "validation_fraction": 0.2,
+            "test_fraction": 0.1,
+        }

-        _LOGGER.log_action_start_against_resource("Exporting", "data", self)
+        Args:
+            output_dir (str):
+                Required. The Google Cloud Storage location where the output is to
+                be written to. In the given directory a new directory will be
+                created with name:
+                ``export-data-<dataset-display-name>-<timestamp-of-export-call>``
+                where timestamp is in YYYYMMDDHHMMSS format. All export
+                output will be written into that directory. Inside that
+                directory, annotations with the same schema will be grouped
+                into sub directories which are named with the corresponding
+                annotations' schema title. Inside these sub directories, a
+                schema.yaml will be created to describe the output format.

-        export_lro = self.api_client.export_data(
-            name=self.resource_name, export_config=export_data_config
-        )
+                If the uri doesn't end with '/', a '/' will be automatically
+                appended. The directory is created if it doesn't exist.
+            annotation_filter (str):
+                Optional. An expression for filtering what part of the Dataset
+                is to be exported.
+                Only Annotations that match this filter will be exported.
+                The filter syntax is the same as in
+                [ListAnnotations][DatasetService.ListAnnotations].
+            saved_query_id (str):
+                Optional. The ID of a SavedQuery (annotation set) under this
+                Dataset used for filtering Annotations for training.
+
+                Only used for custom training data export use cases.
+                Only applicable to Datasets that have SavedQueries.
+
+                Only Annotations that are associated with this SavedQuery are
+                used in respectively training. When used in conjunction with
+                annotations_filter, the Annotations used for training are
+                filtered by both saved_query_id and annotations_filter.
+
+                Only one of saved_query_id and annotation_schema_uri should be
+                specified as both of them represent the same thing: problem
+                type.
+            annotation_schema_uri (str):
+                Optional. The Cloud Storage URI that points to a YAML file
+                describing the annotation schema. The schema is defined as an
+                OpenAPI 3.0.2 Schema Object. The schema files that can be used
+                here are found in
+                gs://google-cloud-aiplatform/schema/dataset/annotation/, note
+                that the chosen schema must be consistent with
+                metadata_schema_uri of this Dataset.
+
+                Only used for custom training data export use cases.
+                Only applicable to Datasets that have DataItems and
+                Annotations.
+
+                Only Annotations that both match this schema and belong to
+                DataItems not ignored by the split method are used in
+                respectively training, validation or test role, depending on the
+                role of the DataItem they are on.
+
+                When used in conjunction with annotations_filter, the
+                Annotations used for training are filtered by both
+                annotations_filter and annotation_schema_uri.
+            split (Union[Dict[str, str], Dict[str, float]]):
+                The instructions how the export data should be split between the
+                training, validation and test sets.

-        _LOGGER.log_action_started_against_resource_with_lro(
-            "Export", "data", self.__class__, export_lro
+        Returns:
+            export_data_response (Dict):
+                Response message for DatasetService.ExportData in Dictionary
+                format.
+        """
+        split = self._validate_and_convert_export_split(split)
+
+        return json_format.MessageToDict(
+            self._get_completed_export_data_operation(
+                output_dir,
+                gca_dataset.ExportDataConfig.ExportUse.CUSTOM_CODE_TRAINING,
+                annotation_filter,
+                saved_query_id,
+                annotation_schema_uri,
+                split,
+            )
         )

-        export_data_response = export_lro.result()
-
-        _LOGGER.log_action_completed_against_resource("data", "export", self)
-
-        return export_data_response.exported_files
-
     def update(
         self,
         *,
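
The split argument documented above accepts either of two dict shapes; _validate_and_convert_export_split turns a filter-shaped dict into a gca_dataset.ExportFilterSplit and a fraction-shaped dict into a gca_dataset.ExportFractionSplit before the export request is built. A sketch of the filter-split variant; the dataset ID, bucket, and saved query ID are placeholders, not values from this commit:

from google.cloud import aiplatform

dataset = aiplatform.ImageDataset("1234567890")  # placeholder dataset ID

filter_split = {
    "training_filter": "labels.aiplatform.googleapis.com/ml_use=training",
    "validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation",
    "test_filter": "labels.aiplatform.googleapis.com/ml_use=test",
}

# Converted internally to gca_dataset.ExportFilterSplit and attached to the
# ExportDataConfig as filter_split.
response = dataset.export_data_for_custom_training(
    output_dir="gs://my-bucket/exports",  # placeholder bucket
    saved_query_id="1234",  # placeholder saved query (annotation set) ID
    split=filter_split,
)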

tests/system/aiplatform/test_dataset.py (+27 -0)
@@ -382,6 +382,33 @@ def test_export_data(self, storage_client, staging_bucket):

         assert blob  # Verify the returned GCS export path exists

+    def test_export_data_for_custom_training(self, staging_bucket):
+        """Get an existing dataset, export data to a newly created folder in
+        Google Cloud Storage, then verify data was successfully exported."""
+
+        # pylint: disable=protected-access
+        # Custom training data export should be generic, hence using the base
+        # _Dataset class here in test. In practice, users should be able to
+        # use this function in any inherited classes of _Dataset.
+        dataset = aiplatform._Dataset(dataset_name=_TEST_TEXT_DATASET_ID)
+
+        split = {
+            "training_fraction": 0.6,
+            "validation_fraction": 0.2,
+            "test_fraction": 0.2,
+        }
+
+        export_data_response = dataset.export_data_for_custom_training(
+            output_dir=f"gs://{staging_bucket.name}",
+            annotation_schema_uri="gs://google-cloud-aiplatform/schema/dataset/annotation/text_classification_1.0.0.yaml",
+            split=split,
+        )
+
+        # Ensure three output paths (training, validation and test) are provided
+        assert len(export_data_response["exported_files"]) == 3
+        # Ensure data stats are calculated and present
+        assert export_data_response["data_stats"]["training_data_items_count"] > 0
+
     def test_update_dataset(self):
         """Create a new dataset and use update() method to change its display_name, labels, and description.
         Then confirm these fields of the dataset was successfully modifed."""