# limitations under the License.
#

-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from google.api_core import operation
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform.compat.services import dataset_service_client
from google.cloud.aiplatform.compat.types import (
    dataset as gca_dataset,
+   dataset_service as gca_dataset_service,
    encryption_spec as gca_encryption_spec,
    io as gca_io,
)
from google.cloud.aiplatform.datasets import _datasources
from google.protobuf import field_mask_pb2
+from google.protobuf import json_format

_LOGGER = base.Logger(__name__)
@@ -561,6 +563,120 @@ def import_data(
        )
        return self

+    def _validate_and_convert_export_split(
+        self,
+        split: Union[Dict[str, str], Dict[str, float]],
+    ) -> Union[gca_dataset.ExportFilterSplit, gca_dataset.ExportFractionSplit]:
+        """
+        Validates the split for data export. Valid splits are dicts
+        encoding the contents of the proto messages ExportFilterSplit or
+        ExportFractionSplit. If the split is valid, this function returns
+        the corresponding converted proto message.
+
+        Args:
+            split (Union[Dict[str, str], Dict[str, float]]):
+                The instructions on how the export data should be split between
+                the training, validation and test sets.
+        """
+        if len(split) != 3:
+            raise ValueError(
+                "The provided split for data export does not provide enough "
+                "information. It must have three fields, mapping to training, "
+                "validation and test splits respectively."
+            )
+
+        if not ("training_filter" in split or "training_fraction" in split):
+            raise ValueError(
+                "The provided filter for data export does not provide enough "
+                "information. It must have three fields, mapping to training, "
+                "validation and test respectively."
+            )
+
+        if "training_filter" in split:
+            if (
+                "validation_filter" in split
+                and "test_filter" in split
+                and isinstance(split["training_filter"], str)
+                and isinstance(split["validation_filter"], str)
+                and isinstance(split["test_filter"], str)
+            ):
+                return gca_dataset.ExportFilterSplit(
+                    training_filter=split["training_filter"],
+                    validation_filter=split["validation_filter"],
+                    test_filter=split["test_filter"],
+                )
+            else:
+                raise ValueError(
+                    "The provided ExportFilterSplit does not contain all "
+                    "three required fields: training_filter, "
+                    "validation_filter and test_filter."
+                )
+        else:
+            if (
+                "validation_fraction" in split
+                and "test_fraction" in split
+                and isinstance(split["training_fraction"], float)
+                and isinstance(split["validation_fraction"], float)
+                and isinstance(split["test_fraction"], float)
+            ):
+                return gca_dataset.ExportFractionSplit(
+                    training_fraction=split["training_fraction"],
+                    validation_fraction=split["validation_fraction"],
+                    test_fraction=split["test_fraction"],
+                )
+            else:
+                raise ValueError(
+                    "The provided ExportFractionSplit does not contain all "
+                    "three required fields: training_fraction, "
+                    "validation_fraction and test_fraction."
+                )
+
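For reference, a minimal sketch of the two dict shapes this helper accepts and the proto messages they are converted to. The field names come from the diff above and the example values mirror the docstring of `export_data_for_custom_training` further down; it assumes the `google-cloud-aiplatform` package is importable.

```python
from google.cloud.aiplatform.compat.types import dataset as gca_dataset

# A fraction split: exactly three float-valued keys describing the
# training/validation/test proportions.
fraction_split = {
    "training_fraction": 0.7,
    "validation_fraction": 0.2,
    "test_fraction": 0.1,
}
# For this shape the helper returns the equivalent of:
gca_dataset.ExportFractionSplit(
    training_fraction=0.7, validation_fraction=0.2, test_fraction=0.1
)

# A filter split: exactly three string-valued keys holding ml_use filters.
filter_split = {
    "training_filter": "labels.aiplatform.googleapis.com/ml_use=training",
    "validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation",
    "test_filter": "labels.aiplatform.googleapis.com/ml_use=test",
}
# For this shape the helper returns a gca_dataset.ExportFilterSplit instead.
```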
+    def _get_completed_export_data_operation(
+        self,
+        output_dir: str,
+        export_use: Optional[gca_dataset.ExportDataConfig.ExportUse] = None,
+        annotation_filter: Optional[str] = None,
+        saved_query_id: Optional[str] = None,
+        annotation_schema_uri: Optional[str] = None,
+        split: Optional[
+            Union[gca_dataset.ExportFilterSplit, gca_dataset.ExportFractionSplit]
+        ] = None,
+    ) -> gca_dataset_service.ExportDataResponse:
+        self.wait()
+
+        # TODO(b/171311614): Add support for BigQuery export path
+        export_data_config = gca_dataset.ExportDataConfig(
+            gcs_destination=gca_io.GcsDestination(output_uri_prefix=output_dir)
+        )
+        if export_use is not None:
+            export_data_config.export_use = export_use
+        if annotation_filter is not None:
+            export_data_config.annotation_filter = annotation_filter
+        if saved_query_id is not None:
+            export_data_config.saved_query_id = saved_query_id
+        if annotation_schema_uri is not None:
+            export_data_config.annotation_schema_uri = annotation_schema_uri
+        if split is not None:
+            if isinstance(split, gca_dataset.ExportFilterSplit):
+                export_data_config.filter_split = split
+            elif isinstance(split, gca_dataset.ExportFractionSplit):
+                export_data_config.fraction_split = split
+
+        _LOGGER.log_action_start_against_resource("Exporting", "data", self)
+
+        export_lro = self.api_client.export_data(
+            name=self.resource_name, export_config=export_data_config
+        )
+
+        _LOGGER.log_action_started_against_resource_with_lro(
+            "Export", "data", self.__class__, export_lro
+        )
+
+        export_data_response = export_lro.result()
+
+        _LOGGER.log_action_completed_against_resource("data", "export", self)
+
+        return export_data_response
+
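To make the field mapping concrete, here is a hedged sketch of roughly the `ExportDataConfig` this helper assembles for a custom-training export with a fraction split. The field and enum names are taken from the code above; the bucket path and fractions are placeholders.

```python
from google.cloud.aiplatform.compat.types import dataset as gca_dataset, io as gca_io

# Roughly what _get_completed_export_data_operation builds before calling
# api_client.export_data(name=..., export_config=config) and waiting on the LRO.
config = gca_dataset.ExportDataConfig(
    gcs_destination=gca_io.GcsDestination(output_uri_prefix="gs://my-bucket/exports/")
)
# Only set for the custom training path (see export_data_for_custom_training below).
config.export_use = gca_dataset.ExportDataConfig.ExportUse.CUSTOM_CODE_TRAINING
# The split is attached to whichever oneof field matches its converted proto type.
config.fraction_split = gca_dataset.ExportFractionSplit(
    training_fraction=0.8, validation_fraction=0.1, test_fraction=0.1
)
```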
    # TODO(b/174751568) add optional sync support
    def export_data(self, output_dir: str) -> Sequence[str]:
        """Exports data to output dir to GCS.
@@ -585,29 +701,113 @@ def export_data(self, output_dir: str) -> Sequence[str]:
            exported_files (Sequence[str]):
                All of the files that are exported in this export operation.
        """
-        self.wait()
+        return self._get_completed_export_data_operation(output_dir).exported_files

-        # TODO(b/171311614): Add support for BigQuery export path
-        export_data_config = gca_dataset.ExportDataConfig(
-            gcs_destination=gca_io.GcsDestination(output_uri_prefix=output_dir)
-        )
+    def export_data_for_custom_training(
+        self,
+        output_dir: str,
+        annotation_filter: Optional[str] = None,
+        saved_query_id: Optional[str] = None,
+        annotation_schema_uri: Optional[str] = None,
+        split: Optional[Union[Dict[str, str], Dict[str, float]]] = None,
+    ) -> Dict[str, Any]:
+        """Exports data to output dir to GCS for custom training use case.
+
+        Example annotation_schema_uri (image classification):
+        gs://google-cloud-aiplatform/schema/dataset/annotation/image_classification_1.0.0.yaml
+
+        Example split (filter split):
+        {
+            "training_filter": "labels.aiplatform.googleapis.com/ml_use=training",
+            "validation_filter": "labels.aiplatform.googleapis.com/ml_use=validation",
+            "test_filter": "labels.aiplatform.googleapis.com/ml_use=test",
+        }
+        Example split (fraction split):
+        {
+            "training_fraction": 0.7,
+            "validation_fraction": 0.2,
+            "test_fraction": 0.1,
+        }

-        _LOGGER.log_action_start_against_resource("Exporting", "data", self)
+        Args:
+            output_dir (str):
+                Required. The Google Cloud Storage location where the output is to
+                be written to. In the given directory a new directory will be
+                created with name:
+                ``export-data-<dataset-display-name>-<timestamp-of-export-call>``
+                where timestamp is in YYYYMMDDHHMMSS format. All export
+                output will be written into that directory. Inside that
+                directory, annotations with the same schema will be grouped
+                into sub directories which are named with the corresponding
+                annotations' schema title. Inside these sub directories, a
+                schema.yaml will be created to describe the output format.

-        export_lro = self.api_client.export_data(
-            name=self.resource_name, export_config=export_data_config
-        )
+                If the uri doesn't end with '/', a '/' will be automatically
+                appended. The directory is created if it doesn't exist.
+            annotation_filter (str):
+                Optional. An expression for filtering what part of the Dataset
+                is to be exported.
+                Only Annotations that match this filter will be exported.
+                The filter syntax is the same as in
+                [ListAnnotations][DatasetService.ListAnnotations].
+            saved_query_id (str):
+                Optional. The ID of a SavedQuery (annotation set) under this
+                Dataset used for filtering Annotations for training.
+
+                Only used for custom training data export use cases.
+                Only applicable to Datasets that have SavedQueries.
+
+                Only Annotations that are associated with this SavedQuery are
+                used for training. When used in conjunction with
+                annotations_filter, the Annotations used for training are
+                filtered by both saved_query_id and annotations_filter.
+
+                Only one of saved_query_id and annotation_schema_uri should be
+                specified, as both of them represent the same thing: problem
+                type.
+            annotation_schema_uri (str):
+                Optional. The Cloud Storage URI that points to a YAML file
+                describing the annotation schema. The schema is defined as an
+                OpenAPI 3.0.2 Schema Object. The schema files that can be used
+                here are found in
+                gs://google-cloud-aiplatform/schema/dataset/annotation/; note
+                that the chosen schema must be consistent with the
+                metadata_schema_uri of this Dataset.
+
+                Only used for custom training data export use cases.
+                Only applicable to Datasets that have DataItems and
+                Annotations.
+
+                Only Annotations that both match this schema and belong to
+                DataItems not ignored by the split method are used in the
+                respective training, validation or test role, depending on
+                the role of the DataItem they are on.
+
+                When used in conjunction with annotations_filter, the
+                Annotations used for training are filtered by both
+                annotations_filter and annotation_schema_uri.
+            split (Union[Dict[str, str], Dict[str, float]]):
+                The instructions on how the export data should be split between
+                the training, validation and test sets.

-        _LOGGER.log_action_started_against_resource_with_lro(
-            "Export", "data", self.__class__, export_lro
+        Returns:
+            export_data_response (Dict):
+                Response message for DatasetService.ExportData in Dictionary
+                format.
+        """
+        split = self._validate_and_convert_export_split(split)
+
+        return json_format.MessageToDict(
+            self._get_completed_export_data_operation(
+                output_dir,
+                gca_dataset.ExportDataConfig.ExportUse.CUSTOM_CODE_TRAINING,
+                annotation_filter,
+                saved_query_id,
+                annotation_schema_uri,
+                split,
+            )
        )

-        export_data_response = export_lro.result()
-
-        _LOGGER.log_action_completed_against_resource("data", "export", self)
-
-        return export_data_response.exported_files
-
    def update(
        self,
        *,
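For context, a minimal usage sketch of the existing `export_data` and the new `export_data_for_custom_training` added in this diff. The project, dataset resource name, and bucket are placeholders, and `ImageDataset` stands in for any dataset class that inherits these methods from the base dataset.

```python
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

# Hypothetical dataset resource name; any concrete dataset class built on the
# base dataset in this file exposes both export methods.
ds = aiplatform.ImageDataset(
    "projects/my-project/locations/us-central1/datasets/1234567890"
)

# Existing behavior: export and receive the list of exported files.
exported_files = ds.export_data(output_dir="gs://my-bucket/exports/")

# New behavior: export for custom training with a fraction split; the
# ExportData response is returned as a dict via json_format.MessageToDict.
response = ds.export_data_for_custom_training(
    output_dir="gs://my-bucket/exports/",
    annotation_schema_uri=(
        "gs://google-cloud-aiplatform/schema/dataset/annotation/"
        "image_classification_1.0.0.yaml"
    ),
    split={
        "training_fraction": 0.7,
        "validation_fraction": 0.2,
        "test_fraction": 0.1,
    },
)
```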