 # limitations under the License.
 #
 
-from typing import List, Optional
+import datetime
+import re
+from typing import Any, Dict, List, Optional
 
+from google.auth import credentials as auth_credentials
+from google.cloud import aiplatform_v1beta1
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import compat
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import pipeline_job_schedules
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.constants import pipeline as pipeline_constants
+from google.cloud.aiplatform.metadata import constants as metadata_constants
+from google.cloud.aiplatform.metadata import experiment_resources
 from google.cloud.aiplatform.pipeline_jobs import (
     PipelineJob as PipelineJobGa,
 )
 from google.cloud.aiplatform_v1.services.pipeline_service import (
     PipelineServiceClient as PipelineServiceClientGa,
 )
-from google.cloud import aiplatform_v1beta1
-from google.cloud.aiplatform import compat, pipeline_job_schedules
-from google.cloud.aiplatform import initializer
-from google.cloud.aiplatform import utils
 
-from google.cloud.aiplatform.metadata import constants as metadata_constants
-from google.cloud.aiplatform.metadata import experiment_resources
+from google.protobuf import json_format
+
+
+_LOGGER = base.Logger(__name__)
+
+# Pattern for valid names used as a Vertex resource name.
+_VALID_NAME_PATTERN = pipeline_constants._VALID_NAME_PATTERN
+
+# Pattern for an Artifact Registry URL.
+_VALID_AR_URL = pipeline_constants._VALID_AR_URL
+
+# Pattern for any JSON or YAML file over HTTPS.
+_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL
+
+
+def _get_current_time() -> datetime.datetime:
+    """Gets the current timestamp."""
+    return datetime.datetime.now()
+
+
+def _set_enable_caching_value(
+    pipeline_spec: Dict[str, Any], enable_caching: bool
+) -> None:
+    """Sets the caching option for every task in the pipeline spec.
+
+    Args:
+        pipeline_spec (Dict[str, Any]):
+            Required. The dictionary of the pipeline spec.
+        enable_caching (bool):
+            Required. Whether to enable caching.
+    """
+    for component in [pipeline_spec["root"]] + list(
+        pipeline_spec["components"].values()
+    ):
+        if "dag" in component:
+            for task in component["dag"]["tasks"].values():
+                task["cachingOptions"] = {"enableCache": enable_caching}
 
 
 class _PipelineJob(
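
For context, here is a minimal sketch of what the new _set_enable_caching_value helper does to a compiled spec. The pipeline_spec dict below is invented for illustration and assumes the helper defined above is in scope:

# Hypothetical, minimal pipeline spec (not from this change).
pipeline_spec = {
    "root": {"dag": {"tasks": {"train": {}}}},
    "components": {"comp-train": {"dag": {"tasks": {"step": {}}}}},
}

# Force-disable caching for every task, overriding compile-time settings.
_set_enable_caching_value(pipeline_spec, enable_caching=False)

assert pipeline_spec["root"]["dag"]["tasks"]["train"]["cachingOptions"] == {
    "enableCache": False
}
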
@@ -42,6 +85,192 @@ class _PipelineJob(
 ):
     """Preview PipelineJob resource for Vertex AI."""
 
+    def __init__(
+        self,
+        display_name: str,
+        template_path: str,
+        job_id: Optional[str] = None,
+        pipeline_root: Optional[str] = None,
+        parameter_values: Optional[Dict[str, Any]] = None,
+        input_artifacts: Optional[Dict[str, str]] = None,
+        enable_caching: Optional[bool] = None,
+        encryption_spec_key_name: Optional[str] = None,
+        labels: Optional[Dict[str, str]] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        failure_policy: Optional[str] = None,
+        enable_preflight_validations: Optional[bool] = False,
+    ):
+        """Retrieves a PipelineJob resource and instantiates its
+        representation.
+
+        Args:
+            display_name (str):
+                Required. The user-defined name of this Pipeline.
+            template_path (str):
+                Required. The path of the PipelineJob or PipelineSpec JSON or
+                YAML file. It can be a local path, a Google Cloud Storage URI
+                (e.g. "gs://project.name"), an Artifact Registry URI (e.g.
+                "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"), or an
+                HTTPS URI.
+            job_id (str):
+                Optional. The unique ID of the job run. If not specified, the
+                pipeline name plus a timestamp will be used.
+            pipeline_root (str):
+                Optional. The root of the pipeline outputs. If not set, the
+                staging bucket set in aiplatform.init will be used. If that's
+                not set, a pipeline-specific artifacts bucket will be used.
+            parameter_values (Dict[str, Any]):
+                Optional. The mapping from runtime parameter names to the
+                values that control the pipeline run.
+            input_artifacts (Dict[str, str]):
+                Optional. The mapping from the runtime parameter name for this
+                artifact to its resource ID. For example: "vertex_model":"456".
+                Note: the full resource name
+                ("projects/123/locations/us-central1/metadataStores/default/artifacts/456")
+                cannot be used.
+            enable_caching (bool):
+                Optional. Whether to turn on caching for the run.
+
+                If this is not set, it defaults to the compile-time settings,
+                which are True for all tasks by default, while users may
+                specify different caching options for individual tasks.
+
+                If this is set, the setting applies to all tasks in the
+                pipeline and overrides the compile-time settings.
+            encryption_spec_key_name (str):
+                Optional. The Cloud KMS resource identifier of the customer
+                managed encryption key used to protect the job. Has the form:
+                ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+                The key needs to be in the same region as where the compute
+                resource is created.
+
+                If this is set, all resources created by the PipelineJob will
+                be encrypted with the provided encryption key.
+
+                Overrides encryption_spec_key_name set in aiplatform.init.
+            labels (Dict[str, str]):
+                Optional. The user-defined metadata to organize PipelineJobs.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to create this PipelineJob.
+                Overrides credentials set in aiplatform.init.
+            project (str):
+                Optional. The project that you want to run this PipelineJob in.
+                If not set, the project set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to create PipelineJob. If not set, the
+                location set in aiplatform.init will be used.
+            failure_policy (str):
+                Optional. The failure policy, either "slow" or "fast".
+                Currently, the default of a pipeline is that the pipeline will
+                continue to run until no more tasks can be executed, also known
+                as PIPELINE_FAILURE_POLICY_FAIL_SLOW (corresponds to "slow").
+                However, if a pipeline is set to
+                PIPELINE_FAILURE_POLICY_FAIL_FAST (corresponds to "fast"), it
+                will stop scheduling any new tasks when a task has failed. Any
+                scheduled tasks will continue to completion.
+            enable_preflight_validations (bool):
+                Optional. Whether to enable preflight validations.
+
+        Raises:
+            ValueError: If job_id or labels have an incorrect format.
+        """
+
+        super().__init__(
+            display_name=display_name,
+            template_path=template_path,
+            job_id=job_id,
+            pipeline_root=pipeline_root,
+            parameter_values=parameter_values,
+            input_artifacts=input_artifacts,
+            enable_caching=enable_caching,
+            encryption_spec_key_name=encryption_spec_key_name,
+            labels=labels,
+            credentials=credentials,
+            project=project,
+            location=location,
+            failure_policy=failure_policy,
+        )
+
+        # Rebuild the v1beta1 versions of pipeline_job and runtime_config.
+        pipeline_json = utils.yaml_utils.load_yaml(
+            template_path, self.project, self.credentials
+        )
+
+        # pipeline_json can be either a PipelineJob or a PipelineSpec.
+        if pipeline_json.get("pipelineSpec") is not None:
+            pipeline_job = pipeline_json
+            pipeline_root = (
+                pipeline_root
+                or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
+                or pipeline_job["runtimeConfig"].get("gcsOutputDirectory")
+                or initializer.global_config.staging_bucket
+            )
+        else:
+            pipeline_job = {
+                "pipelineSpec": pipeline_json,
+                "runtimeConfig": {},
+            }
+            pipeline_root = (
+                pipeline_root
+                or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
+                or initializer.global_config.staging_bucket
+            )
+        pipeline_root = (
+            pipeline_root
+            or utils.gcs_utils.generate_gcs_directory_for_pipeline_artifacts(
+                project=project,
+                location=location,
+            )
+        )
+        builder = utils.pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
+            pipeline_job
+        )
+        builder.update_pipeline_root(pipeline_root)
+        builder.update_runtime_parameters(parameter_values)
+        builder.update_input_artifacts(input_artifacts)
+
+        builder.update_failure_policy(failure_policy)
+        runtime_config_dict = builder.build()
+
+        runtime_config = aiplatform_v1beta1.PipelineJob.RuntimeConfig()._pb
+        json_format.ParseDict(runtime_config_dict, runtime_config)
+
+        pipeline_name = pipeline_job["pipelineSpec"]["pipelineInfo"]["name"]
+        self.job_id = job_id or "{pipeline_name}-{timestamp}".format(
+            pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
+            .lstrip("-")
+            .rstrip("-"),
+            timestamp=_get_current_time().strftime("%Y%m%d%H%M%S"),
+        )
+        if not _VALID_NAME_PATTERN.match(self.job_id):
+            raise ValueError(
+                f"Generated job ID: {self.job_id} is illegal as a Vertex pipelines job ID. "
+                "Expecting an ID following the regex pattern "
+                f'"{_VALID_NAME_PATTERN.pattern[1:-1]}"'
+            )
+
+        if enable_caching is not None:
+            _set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)
+
+        pipeline_job_args = {
+            "display_name": display_name,
+            "pipeline_spec": pipeline_job["pipelineSpec"],
+            "labels": labels,
+            "runtime_config": runtime_config,
+            "encryption_spec": initializer.global_config.get_encryption_spec(
+                encryption_spec_key_name=encryption_spec_key_name
+            ),
+            "preflight_validations": enable_preflight_validations,
+        }
+
+        if _VALID_AR_URL.match(template_path) or _VALID_HTTPS_URL.match(template_path):
+            pipeline_job_args["template_uri"] = template_path
+
+        self._v1_beta1_pipeline_job = aiplatform_v1beta1.PipelineJob(
+            **pipeline_job_args
+        )
+
     def create_schedule(
         self,
         cron_expression: str,
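
To make the new constructor concrete, here is a hedged usage sketch. It assumes this preview class is reachable as google.cloud.aiplatform.preview.pipeline_jobs._PipelineJob; the project, bucket, and template file names are placeholders, not values from this change:

from google.cloud import aiplatform
from google.cloud.aiplatform.preview import pipeline_jobs

aiplatform.init(project="my-project", location="us-central1")

# Build the preview job; this also prepares the v1beta1 PipelineJob proto.
job = pipeline_jobs._PipelineJob(
    display_name="my-pipeline",
    template_path="gs://my-bucket/training_pipeline.yaml",
    pipeline_root="gs://my-bucket/pipeline-root",
    parameter_values={"learning_rate": 0.01},
    enable_caching=False,
    enable_preflight_validations=True,
)
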
@@ -180,3 +409,79 @@ def batch_delete(
         v1beta1_client = client.select_version(compat.V1BETA1)
         operation = v1beta1_client.batch_delete_pipeline_jobs(request)
         return operation.result()
+
+    def submit(
+        self,
+        service_account: Optional[str] = None,
+        network: Optional[str] = None,
+        reserved_ip_ranges: Optional[List[str]] = None,
+        create_request_timeout: Optional[float] = None,
+        job_id: Optional[str] = None,
+    ) -> None:
+        """Runs this configured PipelineJob.
+
+        Args:
+            service_account (str):
+                Optional. Specifies the service account for the workload run-as
+                account. Users submitting jobs must have act-as permission on
+                this run-as account.
+            network (str):
+                Optional. The full name of the Compute Engine network to which
+                the job should be peered. For example,
+                projects/12345/global/networks/myVPC.
+
+                Private services access must already be configured for the
+                network. If left unspecified, the network set in
+                aiplatform.init will be used. Otherwise, the job is not peered
+                with any network.
+            reserved_ip_ranges (List[str]):
+                Optional. A list of names for the reserved IP ranges under the
+                VPC network that can be used for this PipelineJob's workload.
+                For example: ['vertex-ai-ip-range'].
+
+                If left unspecified, the job will be deployed to any IP ranges
+                under the provided VPC network.
+            create_request_timeout (float):
+                Optional. The timeout for the create request in seconds.
+            job_id (str):
+                Optional. The ID to use for the PipelineJob, which will become
+                the final component of the PipelineJob name. If not provided,
+                an ID will be automatically generated.
+        """
+        network = network or initializer.global_config.network
+        service_account = service_account or initializer.global_config.service_account
+        gca_resource = self._v1_beta1_pipeline_job
+
+        if service_account:
+            gca_resource.service_account = service_account
+
+        if network:
+            gca_resource.network = network
+
+        if reserved_ip_ranges:
+            gca_resource.reserved_ip_ranges = reserved_ip_ranges
+        user_project = initializer.global_config.project
+        user_location = initializer.global_config.location
+        parent = initializer.global_config.common_location_path(
+            project=user_project, location=user_location
+        )
+
+        client = self._instantiate_client(
+            location=user_location,
+            appended_user_agent=["preview-pipeline-job-submit"],
+        )
+        v1beta1_client = client.select_version(compat.V1BETA1)
+
+        _LOGGER.log_create_with_lro(self.__class__)
+
+        request = aiplatform_v1beta1.CreatePipelineJobRequest(
+            parent=parent,
+            pipeline_job=self._v1_beta1_pipeline_job,
+            pipeline_job_id=job_id or self.job_id,
+        )
+
+        response = v1beta1_client.create_pipeline_job(request=request)
+
+        self._gca_resource = response
+
+        _LOGGER.log_create_complete_with_getter(
+            self.__class__, self._gca_resource, "pipeline_job"
+        )
+
+        _LOGGER.info("View Pipeline Job:\n%s" % self._dashboard_uri())
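
Continuing the sketch above, submitting the configured job goes through the v1beta1 client; the service account, network, and job ID below are again placeholders:

# Creates the PipelineJob via the v1beta1 API and logs the dashboard URI.
job.submit(
    service_account="pipeline-runner@my-project.iam.gserviceaccount.com",
    network="projects/12345/global/networks/myVPC",
    job_id="my-pipeline-run-20240101",
)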