|
19 | 19 | import logging
|
20 | 20 | import time
|
21 | 21 | import re
|
22 |
| -from typing import Any, Dict, List, Optional, Union |
| 22 | +import tempfile |
| 23 | +from typing import Any, Callable, Dict, List, Optional, Union |
23 | 24 |
|
24 | 25 | from google.auth import credentials as auth_credentials
|
25 | 26 | from google.cloud import aiplatform
|
|
33 | 34 | from google.cloud.aiplatform.metadata import constants as metadata_constants
|
34 | 35 | from google.cloud.aiplatform.metadata import experiment_resources
|
35 | 36 | from google.cloud.aiplatform.metadata import utils as metadata_utils
|
| 37 | +from google.cloud.aiplatform.utils import gcs_utils |
36 | 38 | from google.cloud.aiplatform.utils import yaml_utils
|
37 | 39 | from google.cloud.aiplatform.utils import pipeline_utils
|
38 | 40 | from google.protobuf import json_format
|
@@ -131,7 +133,9 @@ def __init__(
|
131 | 133 | Optional. The unique ID of the job run.
|
132 | 134 | If not specified, pipeline name + timestamp will be used.
|
133 | 135 | pipeline_root (str):
|
134 |
| - Optional. The root of the pipeline outputs. Default to be staging bucket. |
| 136 | + Optional. The root of the pipeline outputs. If not set, the staging bucket |
| 137 | + set in aiplatform.init will be used. If that's not set, a pipeline-specific |
| 138 | + artifacts bucket will be used. |
135 | 139 | parameter_values (Dict[str, Any]):
|
136 | 140 | Optional. The mapping from runtime parameter names to its values that
|
137 | 141 | control the pipeline run.
|
@@ -219,6 +223,13 @@ def __init__(
|
219 | 223 | or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
|
220 | 224 | or initializer.global_config.staging_bucket
|
221 | 225 | )
|
| 226 | + pipeline_root = ( |
| 227 | + pipeline_root |
| 228 | + or gcs_utils.generate_gcs_directory_for_pipeline_artifacts( |
| 229 | + project=project, |
| 230 | + location=location, |
| 231 | + ) |
| 232 | + ) |
222 | 233 | builder = pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
|
223 | 234 | pipeline_job
|
224 | 235 | )
|
@@ -332,6 +343,13 @@ def submit(
|
332 | 343 | if network:
|
333 | 344 | self._gca_resource.network = network
|
334 | 345 |
|
| 346 | + gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist( |
| 347 | + output_artifacts_gcs_dir=self.pipeline_spec.get("gcsOutputDirectory"), |
| 348 | + service_account=self._gca_resource.service_account, |
| 349 | + project=self.project, |
| 350 | + location=self.location, |
| 351 | + ) |
| 352 | + |
335 | 353 | # Prevents logs from being suppressed on TFX pipelines
|
336 | 354 | if self._gca_resource.pipeline_spec.get("sdkVersion", "").startswith("tfx"):
|
337 | 355 | _LOGGER.setLevel(logging.INFO)
|
@@ -772,6 +790,125 @@ def clone(
|
772 | 790 |
|
773 | 791 | return cloned
|
774 | 792 |
|
| 793 | + @staticmethod |
| 794 | + def from_pipeline_func( |
| 795 | + # Parameters for the PipelineJob constructor |
| 796 | + pipeline_func: Callable, |
| 797 | + parameter_values: Optional[Dict[str, Any]] = None, |
| 798 | + output_artifacts_gcs_dir: Optional[str] = None, |
| 799 | + enable_caching: Optional[bool] = None, |
| 800 | + context_name: Optional[str] = "pipeline", |
| 801 | + display_name: Optional[str] = None, |
| 802 | + labels: Optional[Dict[str, str]] = None, |
| 803 | + job_id: Optional[str] = None, |
| 804 | + # Parameters for the Vertex SDK |
| 805 | + project: Optional[str] = None, |
| 806 | + location: Optional[str] = None, |
| 807 | + credentials: Optional[auth_credentials.Credentials] = None, |
| 808 | + encryption_spec_key_name: Optional[str] = None, |
| 809 | + ) -> "PipelineJob": |
| 810 | + """Creates PipelineJob by compiling a pipeline function. |
| 811 | +
|
| 812 | + Args: |
| 813 | + pipeline_func (Callable): |
| 814 | + Required. A pipeline function to compile. |
| 815 | + A pipeline function creates instances of components and connects |
| 816 | + component inputs to outputs. |
| 817 | + parameter_values (Dict[str, Any]): |
| 818 | + Optional. The mapping from runtime parameter names to its values that |
| 819 | + control the pipeline run. |
| 820 | + output_artifacts_gcs_dir (str): |
| 821 | + Optional. The GCS location of the pipeline outputs. |
| 822 | + A GCS bucket for artifacts will be created if not specified. |
| 823 | + enable_caching (bool): |
| 824 | + Optional. Whether to turn on caching for the run. |
| 825 | +
|
| 826 | + If this is not set, defaults to the compile time settings, which |
| 827 | + are True for all tasks by default, while users may specify |
| 828 | + different caching options for individual tasks. |
| 829 | +
|
| 830 | + If this is set, the setting applies to all tasks in the pipeline. |
| 831 | +
|
| 832 | + Overrides the compile time settings. |
| 833 | + context_name (str): |
| 834 | + Optional. The name of metadata context. Used for cached execution reuse. |
| 835 | + display_name (str): |
| 836 | + Optional. The user-defined name of this Pipeline. |
| 837 | + labels (Dict[str, str]): |
| 838 | + Optional. The user defined metadata to organize PipelineJob. |
| 839 | + job_id (str): |
| 840 | + Optional. The unique ID of the job run. |
| 841 | + If not specified, pipeline name + timestamp will be used. |
| 842 | +
|
| 843 | + project (str): |
| 844 | + Optional. The project that you want to run this PipelineJob in. If not set, |
| 845 | + the project set in aiplatform.init will be used. |
| 846 | + location (str): |
| 847 | + Optional. Location to create PipelineJob. If not set, |
| 848 | + location set in aiplatform.init will be used. |
| 849 | + credentials (auth_credentials.Credentials): |
| 850 | + Optional. Custom credentials to use to create this PipelineJob. |
| 851 | + Overrides credentials set in aiplatform.init. |
| 852 | + encryption_spec_key_name (str): |
| 853 | + Optional. The Cloud KMS resource identifier of the customer |
| 854 | + managed encryption key used to protect the job. Has the |
| 855 | + form: |
| 856 | + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. |
| 857 | + The key needs to be in the same region as where the compute |
| 858 | + resource is created. |
| 859 | +
|
| 860 | + If this is set, then all |
| 861 | + resources created by the PipelineJob will |
| 862 | + be encrypted with the provided encryption key. |
| 863 | +
|
| 864 | + Overrides encryption_spec_key_name set in aiplatform.init. |
| 865 | +
|
| 866 | + Returns: |
| 867 | + A Vertex AI PipelineJob. |
| 868 | +
|
| 869 | + Raises: |
| 870 | + ValueError: If job_id or labels have incorrect format. |
| 871 | + """ |
| 872 | + |
| 873 | + # Importing the KFP module here to prevent import errors when the kfp package is not installed. |
| 874 | + try: |
| 875 | + from kfp.v2 import compiler as compiler_v2 |
| 876 | + except ImportError as err: |
| 877 | + raise RuntimeError( |
| 878 | + "Cannot import the kfp.v2.compiler module. Please install or update the kfp package." |
| 879 | + ) from err |
| 880 | + |
| 881 | + automatic_display_name = " ".join( |
| 882 | + [ |
| 883 | + pipeline_func.__name__.replace("_", " "), |
| 884 | + datetime.datetime.now().isoformat(sep=" "), |
| 885 | + ] |
| 886 | + ) |
| 887 | + display_name = display_name or automatic_display_name |
| 888 | + job_id = job_id or re.sub( |
| 889 | + r"[^-a-z0-9]", "-", automatic_display_name.lower() |
| 890 | + ).strip("-") |
| 891 | + pipeline_file = tempfile.mktemp(suffix=".json") |
| 892 | + compiler_v2.Compiler().compile( |
| 893 | + pipeline_func=pipeline_func, |
| 894 | + pipeline_name=context_name, |
| 895 | + package_path=pipeline_file, |
| 896 | + ) |
| 897 | + pipeline_job = PipelineJob( |
| 898 | + template_path=pipeline_file, |
| 899 | + parameter_values=parameter_values, |
| 900 | + pipeline_root=output_artifacts_gcs_dir, |
| 901 | + enable_caching=enable_caching, |
| 902 | + display_name=display_name, |
| 903 | + job_id=job_id, |
| 904 | + labels=labels, |
| 905 | + project=project, |
| 906 | + location=location, |
| 907 | + credentials=credentials, |
| 908 | + encryption_spec_key_name=encryption_spec_key_name, |
| 909 | + ) |
| 910 | + return pipeline_job |
| 911 | + |
775 | 912 | def get_associated_experiment(self) -> Optional["aiplatform.Experiment"]:
|
776 | 913 | """Gets the aiplatform.Experiment associated with this PipelineJob,
|
777 | 914 | or None if this PipelineJob is not associated with an experiment.
|
|
0 commit comments