Skip to content

Commit f2e70b1

Browse files
feat(sdk): enable loading both JSON and YAML pipelines IR (#1089)
1 parent 936ddf8 commit f2e70b1

File tree

5 files changed

+277
-138
lines changed

5 files changed

+277
-138
lines changed

google/cloud/aiplatform/pipeline_jobs.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from google.cloud.aiplatform import base
2626
from google.cloud.aiplatform import initializer
2727
from google.cloud.aiplatform import utils
28-
from google.cloud.aiplatform.utils import json_utils
28+
from google.cloud.aiplatform.utils import yaml_utils
2929
from google.cloud.aiplatform.utils import pipeline_utils
3030
from google.protobuf import json_format
3131

@@ -112,7 +112,7 @@ def __init__(
112112
display_name (str):
113113
Required. The user-defined name of this Pipeline.
114114
template_path (str):
115-
Required. The path of PipelineJob or PipelineSpec JSON file. It
115+
Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It
116116
can be a local path or a Google Cloud Storage URI.
117117
Example: "gs://project.name"
118118
job_id (str):
@@ -173,9 +173,12 @@ def __init__(
173173
self._parent = initializer.global_config.common_location_path(
174174
project=project, location=location
175175
)
176-
pipeline_json = json_utils.load_json(
176+
177+
# this loads both .yaml and .json files because YAML is a superset of JSON
178+
pipeline_json = yaml_utils.load_yaml(
177179
template_path, self.project, self.credentials
178180
)
181+
179182
# Pipeline_json can be either PipelineJob or PipelineSpec.
180183
if pipeline_json.get("pipelineSpec") is not None:
181184
pipeline_job = pipeline_json

google/cloud/aiplatform/utils/json_utils.py google/cloud/aiplatform/utils/yaml_utils.py

+30-17
Original file line numberDiff line numberDiff line change
@@ -15,70 +15,83 @@
1515
# limitations under the License.
1616
#
1717

18-
import json
1918
from typing import Any, Dict, Optional
2019

2120
from google.auth import credentials as auth_credentials
2221
from google.cloud import storage
2322

2423

25-
def load_json(
24+
def load_yaml(
    path: str,
    project: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> Dict[str, Any]:
    """Loads data from a YAML document.

    Dispatches to the GCS loader for ``gs://`` URIs and to the local-file
    loader for everything else.

    Args:
        path (str):
            Required. The path of the YAML document in Google Cloud Storage or
            local.
        project (str):
            Optional. Project to initiate the Storage client with.
        credentials (auth_credentials.Credentials):
            Optional. Credentials to use with Storage Client.

    Returns:
        A Dict object representing the YAML document.
    """
    # Anything that is not a GCS URI is treated as a local filesystem path.
    if not path.startswith("gs://"):
        return _load_yaml_from_local_file(path)
    return _load_yaml_from_gs_uri(path, project, credentials)
4847

4948

50-
def _load_json_from_gs_uri(
49+
def _load_yaml_from_gs_uri(
    uri: str,
    project: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> Dict[str, Any]:
    """Loads data from a YAML document referenced by a GCS URI.

    Args:
        path (str):
            Required. GCS URI for YAML document.
        project (str):
            Optional. Project to initiate the Storage client with.
        credentials (auth_credentials.Credentials):
            Optional. Credentials to use with Storage Client.

    Returns:
        A Dict object representing the YAML document.
    """
    # pyyaml is an optional extra; fail with an actionable message if absent.
    try:
        import yaml
    except ImportError:
        raise ImportError(
            "pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
            'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
        )
    client = storage.Client(project=project, credentials=credentials)
    gcs_blob = storage.Blob.from_string(uri, client)
    # safe_load accepts bytes directly, so no explicit decode is needed.
    return yaml.safe_load(gcs_blob.download_as_bytes())
7177

7278

73-
def _load_json_from_local_file(file_path: str) -> Dict[str, Any]:
74-
"""Loads data from a JSON local file.
79+
def _load_yaml_from_local_file(file_path: str) -> Dict[str, Any]:
    """Loads data from a YAML local file.

    Args:
        file_path (str):
            Required. The local file path of the YAML document.

    Returns:
        A Dict object representing the YAML document.
    """
    # pyyaml is an optional extra; fail with an actionable message if absent.
    try:
        import yaml
    except ImportError:
        raise ImportError(
            "pyyaml is not installed and is required to parse PipelineJob or PipelineSpec files. "
            'Please install the SDK using "pip install google-cloud-aiplatform[pipelines]"'
        )
    with open(file_path) as yaml_file:
        return yaml.safe_load(yaml_file)

setup.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,17 @@
5252
"pandas >= 1.0.0",
5353
"pyarrow >= 6.0.1",
5454
]
55-
55+
pipelines_extra_requires = [
56+
"pyyaml>=5.3,<6",
57+
]
5658
full_extra_require = list(
5759
set(
5860
tensorboard_extra_require
5961
+ metadata_extra_require
6062
+ xai_extra_require
6163
+ lit_extra_require
6264
+ featurestore_extra_require
65+
+ pipelines_extra_requires
6366
)
6467
)
6568
testing_extra_require = (
@@ -110,6 +113,7 @@
110113
"xai": xai_extra_require,
111114
"lit": lit_extra_require,
112115
"cloud_profiler": profiler_extra_require,
116+
"pipelines": pipelines_extra_requires,
113117
},
114118
python_requires=">=3.6",
115119
classifiers=[

Comments (0)