
Commit e88dc0d

vertex-sdk-bot authored and copybara-github committed

feat: Support PreflightValidation in Preview PipelineJob submit function.

PiperOrigin-RevId: 628707894

1 parent 1341e2c commit e88dc0d

File tree

2 files changed: +358 -7 lines changed

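The change threads a new `enable_preflight_validations` flag from the preview `_PipelineJob` constructor into the v1beta1 `PipelineJob` resource, so the service can validate the job before scheduling it. A minimal usage sketch, assuming the preview class is reachable at the module path shown in the diff; project, bucket, and template values below are placeholders, not part of this commit:

    # Hedged sketch; identifiers are placeholders, not from this commit.
    from google.cloud import aiplatform
    from google.cloud.aiplatform.preview.pipelinejob.pipeline_jobs import _PipelineJob

    aiplatform.init(project="my-project", location="us-central1")

    job = _PipelineJob(
        display_name="my-pipeline",
        template_path="gs://my-bucket/pipeline.yaml",
        enable_preflight_validations=True,  # the flag introduced by this commit
    )
    job.submit()  # builds and sends the v1beta1 CreatePipelineJobRequest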

google/cloud/aiplatform/preview/pipelinejob/pipeline_jobs.py (+312 -7)
@@ -15,21 +15,64 @@
 # limitations under the License.
 #

-from typing import List, Optional
+import datetime
+import re
+from typing import Any, Dict, List, Optional

+from google.auth import credentials as auth_credentials
+from google.cloud import aiplatform_v1beta1
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import compat
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import pipeline_job_schedules
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.constants import pipeline as pipeline_constants
+from google.cloud.aiplatform.metadata import constants as metadata_constants
+from google.cloud.aiplatform.metadata import experiment_resources
 from google.cloud.aiplatform.pipeline_jobs import (
     PipelineJob as PipelineJobGa,
 )
 from google.cloud.aiplatform_v1.services.pipeline_service import (
     PipelineServiceClient as PipelineServiceClientGa,
 )
-from google.cloud import aiplatform_v1beta1
-from google.cloud.aiplatform import compat, pipeline_job_schedules
-from google.cloud.aiplatform import initializer
-from google.cloud.aiplatform import utils

-from google.cloud.aiplatform.metadata import constants as metadata_constants
-from google.cloud.aiplatform.metadata import experiment_resources
+from google.protobuf import json_format
+
+
+_LOGGER = base.Logger(__name__)
+
+# Pattern for valid names used as a Vertex resource name.
+_VALID_NAME_PATTERN = pipeline_constants._VALID_NAME_PATTERN
+
+# Pattern for an Artifact Registry URL.
+_VALID_AR_URL = pipeline_constants._VALID_AR_URL
+
+# Pattern for any JSON or YAML file over HTTPS.
+_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL
+
+
+def _get_current_time() -> datetime.datetime:
+    """Gets the current timestamp."""
+    return datetime.datetime.now()
+
+
+def _set_enable_caching_value(
+    pipeline_spec: Dict[str, Any], enable_caching: bool
+) -> None:
+    """Sets pipeline tasks caching options.
+
+    Args:
+        pipeline_spec (Dict[str, Any]):
+            Required. The dictionary of pipeline spec.
+        enable_caching (bool):
+            Required. Whether to enable caching.
+    """
+    for component in [pipeline_spec["root"]] + list(
+        pipeline_spec["components"].values()
+    ):
+        if "dag" in component:
+            for task in component["dag"]["tasks"].values():
+                task["cachingOptions"] = {"enableCache": enable_caching}


 class _PipelineJob(
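As an aside, `_set_enable_caching_value` simply walks the root DAG plus every component DAG and stamps the same caching option on each task. A toy illustration, assuming the helper is importable from the module this diff touches; the spec below is a made-up minimal fragment of KFP IR with only the keys the helper reads:

    from google.cloud.aiplatform.preview.pipelinejob.pipeline_jobs import (
        _set_enable_caching_value,
    )

    # Made-up spec: a root DAG with one task, one component with one task.
    spec = {
        "root": {"dag": {"tasks": {"train": {}}}},
        "components": {"comp-train": {"dag": {"tasks": {"eval": {}}}}},
    }
    _set_enable_caching_value(spec, enable_caching=False)
    # Every task now carries the same caching option:
    assert spec["root"]["dag"]["tasks"]["train"] == {
        "cachingOptions": {"enableCache": False}
    }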
@@ -42,6 +85,192 @@ class _PipelineJob(
 ):
     """Preview PipelineJob resource for Vertex AI."""

+    def __init__(
+        self,
+        display_name: str,
+        template_path: str,
+        job_id: Optional[str] = None,
+        pipeline_root: Optional[str] = None,
+        parameter_values: Optional[Dict[str, Any]] = None,
+        input_artifacts: Optional[Dict[str, str]] = None,
+        enable_caching: Optional[bool] = None,
+        encryption_spec_key_name: Optional[str] = None,
+        labels: Optional[Dict[str, str]] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        failure_policy: Optional[str] = None,
+        enable_preflight_validations: Optional[bool] = False,
+    ):
+        """Retrieves a PipelineJob resource and instantiates its
+        representation.
+
+        Args:
+            display_name (str):
+                Required. The user-defined name of this Pipeline.
+            template_path (str):
+                Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It
+                can be a local path, a Google Cloud Storage URI (e.g. "gs://project.name"),
+                an Artifact Registry URI (e.g.
+                "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"), or an HTTPS URI.
+            job_id (str):
+                Optional. The unique ID of the job run.
+                If not specified, pipeline name + timestamp will be used.
+            pipeline_root (str):
+                Optional. The root of the pipeline outputs. If not set, the staging bucket
+                set in aiplatform.init will be used. If that's not set, a pipeline-specific
+                artifacts bucket will be used.
+            parameter_values (Dict[str, Any]):
+                Optional. The mapping from runtime parameter names to the values that
+                control the pipeline run.
+            input_artifacts (Dict[str, str]):
+                Optional. The mapping from the runtime parameter name for this artifact
+                to its resource ID. For example: "vertex_model":"456". Note: the full
+                resource name ("projects/123/locations/us-central1/metadataStores/default/artifacts/456")
+                cannot be used.
+            enable_caching (bool):
+                Optional. Whether to turn on caching for the run.
+
+                If this is not set, defaults to the compile time settings, which
+                are True for all tasks by default, while users may specify
+                different caching options for individual tasks.
+
+                If this is set, the setting applies to all tasks in the pipeline.
+
+                Overrides the compile time settings.
+            encryption_spec_key_name (str):
+                Optional. The Cloud KMS resource identifier of the customer
+                managed encryption key used to protect the job. Has the form:
+                ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+                The key needs to be in the same region as where the compute
+                resource is created.
+
+                If this is set, then all resources created by the PipelineJob
+                will be encrypted with the provided encryption key.
+
+                Overrides encryption_spec_key_name set in aiplatform.init.
+            labels (Dict[str, str]):
+                Optional. The user-defined metadata to organize PipelineJob.
+            credentials (auth_credentials.Credentials):
+                Optional. Custom credentials to use to create this PipelineJob.
+                Overrides credentials set in aiplatform.init.
+            project (str):
+                Optional. The project that you want to run this PipelineJob in. If not set,
+                the project set in aiplatform.init will be used.
+            location (str):
+                Optional. Location to create PipelineJob. If not set,
+                the location set in aiplatform.init will be used.
+            failure_policy (str):
+                Optional. The failure policy - "slow" or "fast".
+                Currently, the default of a pipeline is that the pipeline will continue to
+                run until no more tasks can be executed, also known as
+                PIPELINE_FAILURE_POLICY_FAIL_SLOW (corresponds to "slow").
+                However, if a pipeline is set to
+                PIPELINE_FAILURE_POLICY_FAIL_FAST (corresponds to "fast"),
+                it will stop scheduling any new tasks when a task has failed. Any
+                scheduled tasks will continue to completion.
+            enable_preflight_validations (bool):
+                Optional. Whether to enable preflight validations.
+
+        Raises:
+            ValueError: If job_id or labels have an incorrect format.
+        """
+        super().__init__(
+            display_name=display_name,
+            template_path=template_path,
+            job_id=job_id,
+            pipeline_root=pipeline_root,
+            parameter_values=parameter_values,
+            input_artifacts=input_artifacts,
+            enable_caching=enable_caching,
+            encryption_spec_key_name=encryption_spec_key_name,
+            labels=labels,
+            credentials=credentials,
+            project=project,
+            location=location,
+            failure_policy=failure_policy,
+        )
+
+        # Rebuild the v1beta1 versions of pipeline_job and runtime_config.
+        pipeline_json = utils.yaml_utils.load_yaml(
+            template_path, self.project, self.credentials
+        )
+
+        # pipeline_json can be either a PipelineJob or a PipelineSpec.
+        if pipeline_json.get("pipelineSpec") is not None:
+            pipeline_job = pipeline_json
+            pipeline_root = (
+                pipeline_root
+                or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
+                or pipeline_job["runtimeConfig"].get("gcsOutputDirectory")
+                or initializer.global_config.staging_bucket
+            )
+        else:
+            pipeline_job = {
+                "pipelineSpec": pipeline_json,
+                "runtimeConfig": {},
+            }
+            pipeline_root = (
+                pipeline_root
+                or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
+                or initializer.global_config.staging_bucket
+            )
+        pipeline_root = (
+            pipeline_root
+            or utils.gcs_utils.generate_gcs_directory_for_pipeline_artifacts(
+                project=project,
+                location=location,
+            )
+        )
+        builder = utils.pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
+            pipeline_job
+        )
+        builder.update_pipeline_root(pipeline_root)
+        builder.update_runtime_parameters(parameter_values)
+        builder.update_input_artifacts(input_artifacts)
+
+        builder.update_failure_policy(failure_policy)
+        runtime_config_dict = builder.build()
+
+        runtime_config = aiplatform_v1beta1.PipelineJob.RuntimeConfig()._pb
+        json_format.ParseDict(runtime_config_dict, runtime_config)
+
+        pipeline_name = pipeline_job["pipelineSpec"]["pipelineInfo"]["name"]
+        self.job_id = job_id or "{pipeline_name}-{timestamp}".format(
+            pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
+            .lstrip("-")
+            .rstrip("-"),
+            timestamp=_get_current_time().strftime("%Y%m%d%H%M%S"),
+        )
+        if not _VALID_NAME_PATTERN.match(self.job_id):
+            raise ValueError(
+                f"Generated job ID: {self.job_id} is illegal as a Vertex pipelines job ID. "
+                "Expecting an ID following the regex pattern "
+                f'"{_VALID_NAME_PATTERN.pattern[1:-1]}"'
+            )
+
+        if enable_caching is not None:
+            _set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)
+
+        pipeline_job_args = {
+            "display_name": display_name,
+            "pipeline_spec": pipeline_job["pipelineSpec"],
+            "labels": labels,
+            "runtime_config": runtime_config,
+            "encryption_spec": initializer.global_config.get_encryption_spec(
+                encryption_spec_key_name=encryption_spec_key_name
+            ),
+            "preflight_validations": enable_preflight_validations,
+        }
+
+        if _VALID_AR_URL.match(template_path) or _VALID_HTTPS_URL.match(template_path):
+            pipeline_job_args["template_uri"] = template_path
+
+        self._v1_beta1_pipeline_job = aiplatform_v1beta1.PipelineJob(
+            **pipeline_job_args
+        )
+
     def create_schedule(
         self,
         cron_expression: str,
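Worth noting: the constructor above derives a default job ID by lowercasing the pipeline name, collapsing runs of disallowed characters into hyphens, trimming leading and trailing hyphens, and appending a timestamp. A standalone sketch of that sanitization, using the exact regex from the diff and a made-up pipeline name:

    import datetime
    import re

    pipeline_name = "My_Training Pipeline!"  # made-up example
    job_id = "{name}-{ts}".format(
        name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
        .lstrip("-")
        .rstrip("-"),
        ts=datetime.datetime.now().strftime("%Y%m%d%H%M%S"),
    )
    # job_id == "my-training-pipeline-<timestamp>", the lowercase-hyphenated
    # shape the _VALID_NAME_PATTERN check in the constructor expects.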
@@ -180,3 +409,79 @@ def batch_delete(
         v1beta1_client = client.select_version(compat.V1BETA1)
         operation = v1beta1_client.batch_delete_pipeline_jobs(request)
         return operation.result()
+
+    def submit(
+        self,
+        service_account: Optional[str] = None,
+        network: Optional[str] = None,
+        reserved_ip_ranges: Optional[List[str]] = None,
+        create_request_timeout: Optional[float] = None,
+        job_id: Optional[str] = None,
+    ) -> None:
+        """Run this configured PipelineJob.
+
+        Args:
+            service_account (str):
+                Optional. Specifies the service account for workload run-as account.
+                Users submitting jobs must have act-as permission on this run-as account.
+            network (str):
+                Optional. The full name of the Compute Engine network to which the job
+                should be peered. For example, projects/12345/global/networks/myVPC.
+
+                Private services access must already be configured for the network.
+                If left unspecified, the network set in aiplatform.init will be used.
+                Otherwise, the job is not peered with any network.
+            reserved_ip_ranges (List[str]):
+                Optional. A list of names for the reserved IP ranges under the VPC
+                network that can be used for this PipelineJob's workload. For example:
+                ['vertex-ai-ip-range'].
+
+                If left unspecified, the job will be deployed to any IP ranges under
+                the provided VPC network.
+            create_request_timeout (float):
+                Optional. The timeout for the create request in seconds.
+            job_id (str):
+                Optional. The ID to use for the PipelineJob, which will become the final
+                component of the PipelineJob name. If not provided, an ID will be
+                automatically generated.
+        """
+        network = network or initializer.global_config.network
+        service_account = service_account or initializer.global_config.service_account
+        gca_resource = self._v1_beta1_pipeline_job
+
+        if service_account:
+            gca_resource.service_account = service_account
+
+        if network:
+            gca_resource.network = network
+
+        if reserved_ip_ranges:
+            gca_resource.reserved_ip_ranges = reserved_ip_ranges
+        user_project = initializer.global_config.project
+        user_location = initializer.global_config.location
+        parent = initializer.global_config.common_location_path(
+            project=user_project, location=user_location
+        )
+
+        client = self._instantiate_client(
+            location=user_location,
+            appended_user_agent=["preview-pipeline-job-submit"],
+        )
+        v1beta1_client = client.select_version(compat.V1BETA1)
+
+        _LOGGER.log_create_with_lro(self.__class__)
+
+        request = aiplatform_v1beta1.CreatePipelineJobRequest(
+            parent=parent,
+            pipeline_job=self._v1_beta1_pipeline_job,
+            pipeline_job_id=job_id or self.job_id,
+        )
+
+        response = v1beta1_client.create_pipeline_job(request=request)
+
+        self._gca_resource = response
+
+        _LOGGER.log_create_complete_with_getter(
+            self.__class__, self._gca_resource, "pipeline_job"
+        )
+
+        _LOGGER.info("View Pipeline Job:\n%s" % self._dashboard_uri())
