Skip to content

Commit 95b107c

Browse files
authored
feat: Change the Metadata SDK _Context class to an external class (#1519)
* feat: Change the Metadata SDK _Context class to an external class * Add base schema class for context * Add additional context schema types * Add additional context schema types * Add create method to Context. * Fix unit test failure. * add unit tests * fix lint issue * Add Context to root __init__. * correct import path
1 parent fd55daf commit 95b107c

11 files changed

+575
-46
lines changed

google/cloud/aiplatform/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
ExperimentRun = metadata.experiment_run_resource.ExperimentRun
9696
Artifact = metadata.artifact.Artifact
9797
Execution = metadata.execution.Execution
98+
Context = metadata.context.Context
9899

99100

100101
__all__ = (

google/cloud/aiplatform/metadata/context.py

+152-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import proto
2121

22+
from google.auth import credentials as auth_credentials
23+
2224
from google.cloud.aiplatform import base
2325
from google.cloud.aiplatform import utils
2426
from google.cloud.aiplatform.metadata import utils as metadata_utils
@@ -31,10 +33,11 @@
3133
)
3234
from google.cloud.aiplatform.metadata import artifact
3335
from google.cloud.aiplatform.metadata import execution
36+
from google.cloud.aiplatform.metadata import metadata_store
3437
from google.cloud.aiplatform.metadata import resource
3538

3639

37-
class _Context(resource._Resource):
40+
class Context(resource._Resource):
3841
"""Metadata Context resource for Vertex AI"""
3942

4043
_resource_noun = "contexts"
@@ -81,6 +84,153 @@ def get_artifacts(self) -> List[artifact.Artifact]:
8184
credentials=self.credentials,
8285
)
8386

87+
@classmethod
88+
def create(
89+
cls,
90+
schema_title: str,
91+
*,
92+
resource_id: Optional[str] = None,
93+
display_name: Optional[str] = None,
94+
schema_version: Optional[str] = None,
95+
description: Optional[str] = None,
96+
metadata: Optional[Dict] = None,
97+
metadata_store_id: Optional[str] = "default",
98+
project: Optional[str] = None,
99+
location: Optional[str] = None,
100+
credentials: Optional[auth_credentials.Credentials] = None,
101+
) -> "Context":
102+
"""Creates a new Metadata Context.
103+
104+
Args:
105+
schema_title (str):
106+
Required. schema_title identifies the schema title used by the Context.
107+
Please reference https://cloud.google.com/vertex-ai/docs/ml-metadata/system-schemas.
108+
resource_id (str):
109+
Optional. The <resource_id> portion of the Context name with
110+
the format. This is globally unique in a metadataStore:
111+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/Contexts/<resource_id>.
112+
display_name (str):
113+
Optional. The user-defined name of the Context.
114+
schema_version (str):
115+
Optional. schema_version specifies the version used by the Context.
116+
If not set, defaults to use the latest version.
117+
description (str):
118+
Optional. Describes the purpose of the Context to be created.
119+
metadata (Dict):
120+
Optional. Contains the metadata information that will be stored in the Context.
121+
metadata_store_id (str):
122+
Optional. The <metadata_store_id> portion of the resource name with
123+
the format:
124+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/Contexts/<resource_id>
125+
If not provided, the MetadataStore's ID will be set to "default".
126+
project (str):
127+
Optional. Project used to create this Context. Overrides project set in
128+
aiplatform.init.
129+
location (str):
130+
Optional. Location used to create this Context. Overrides location set in
131+
aiplatform.init.
132+
credentials (auth_credentials.Credentials):
133+
Optional. Custom credentials used to create this Context. Overrides
134+
credentials set in aiplatform.init.
135+
136+
Returns:
137+
Context: Instantiated representation of the managed Metadata Context.
138+
"""
139+
return cls._create(
140+
resource_id=resource_id,
141+
schema_title=schema_title,
142+
display_name=display_name,
143+
schema_version=schema_version,
144+
description=description,
145+
metadata=metadata,
146+
metadata_store_id=metadata_store_id,
147+
project=project,
148+
location=location,
149+
credentials=credentials,
150+
)
151+
152+
# TODO() refactor code to move _create to _Resource class.
153+
@classmethod
154+
def _create(
155+
cls,
156+
resource_id: str,
157+
schema_title: str,
158+
display_name: Optional[str] = None,
159+
schema_version: Optional[str] = None,
160+
description: Optional[str] = None,
161+
metadata: Optional[Dict] = None,
162+
metadata_store_id: Optional[str] = "default",
163+
project: Optional[str] = None,
164+
location: Optional[str] = None,
165+
credentials: Optional[auth_credentials.Credentials] = None,
166+
) -> "Context":
167+
"""Creates a new Metadata resource.
168+
169+
Args:
170+
resource_id (str):
171+
Required. The <resource_id> portion of the resource name with
172+
the format:
173+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>.
174+
schema_title (str):
175+
Required. schema_title identifies the schema title used by the resource.
176+
display_name (str):
177+
Optional. The user-defined name of the resource.
178+
schema_version (str):
179+
Optional. schema_version specifies the version used by the resource.
180+
If not set, defaults to use the latest version.
181+
description (str):
182+
Optional. Describes the purpose of the resource to be created.
183+
metadata (Dict):
184+
Optional. Contains the metadata information that will be stored in the resource.
185+
metadata_store_id (str):
186+
The <metadata_store_id> portion of the resource name with
187+
the format:
188+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/<resource_noun>/<resource_id>
189+
If not provided, the MetadataStore's ID will be set to "default".
190+
project (str):
191+
Project used to create this resource. Overrides project set in
192+
aiplatform.init.
193+
location (str):
194+
Location used to create this resource. Overrides location set in
195+
aiplatform.init.
196+
credentials (auth_credentials.Credentials):
197+
Custom credentials used to create this resource. Overrides
198+
credentials set in aiplatform.init.
199+
200+
Returns:
201+
resource (_Resource):
202+
Instantiated representation of the managed Metadata resource.
203+
204+
"""
205+
api_client = cls._instantiate_client(location=location, credentials=credentials)
206+
207+
parent = utils.full_resource_name(
208+
resource_name=metadata_store_id,
209+
resource_noun=metadata_store._MetadataStore._resource_noun,
210+
parse_resource_name_method=metadata_store._MetadataStore._parse_resource_name,
211+
format_resource_name_method=metadata_store._MetadataStore._format_resource_name,
212+
project=project,
213+
location=location,
214+
)
215+
216+
resource = cls._create_resource(
217+
client=api_client,
218+
parent=parent,
219+
resource_id=resource_id,
220+
schema_title=schema_title,
221+
display_name=display_name,
222+
schema_version=schema_version,
223+
description=description,
224+
metadata=metadata,
225+
)
226+
227+
self = cls._empty_constructor(
228+
project=project, location=location, credentials=credentials
229+
)
230+
self._gca_resource = resource
231+
232+
return self
233+
84234
@classmethod
85235
def _create_resource(
86236
cls,
@@ -147,7 +297,7 @@ def _list_resources(
147297
)
148298
return client.list_contexts(request=list_request)
149299

150-
def add_context_children(self, contexts: List["_Context"]):
300+
def add_context_children(self, contexts: List["Context"]):
151301
"""Adds the provided contexts as children of this context.
152302
153303
Args:

google/cloud/aiplatform/metadata/experiment_resources.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,13 @@ def __init__(
119119
)
120120

121121
with _SetLoggerLevel(resource):
122-
experiment_context = context._Context(**metadata_args)
122+
experiment_context = context.Context(**metadata_args)
123123
self._validate_experiment_context(experiment_context)
124124

125125
self._metadata_context = experiment_context
126126

127127
@staticmethod
128-
def _validate_experiment_context(experiment_context: context._Context):
128+
def _validate_experiment_context(experiment_context: context.Context):
129129
"""Validates this context is an experiment context.
130130
131131
Args:
@@ -146,7 +146,7 @@ def _validate_experiment_context(experiment_context: context._Context):
146146
)
147147

148148
@staticmethod
149-
def _is_tensorboard_experiment(context: context._Context) -> bool:
149+
def _is_tensorboard_experiment(context: context.Context) -> bool:
150150
"""Returns True if Experiment is a Tensorboard Experiment created by CustomJob."""
151151
return constants.TENSORBOARD_CUSTOM_JOB_EXPERIMENT_FIELD in context.metadata
152152

@@ -192,7 +192,7 @@ def create(
192192
)
193193

194194
with _SetLoggerLevel(resource):
195-
experiment_context = context._Context._create(
195+
experiment_context = context.Context._create(
196196
resource_id=experiment_name,
197197
display_name=experiment_name,
198198
description=description,
@@ -248,7 +248,7 @@ def get_or_create(
248248
)
249249

250250
with _SetLoggerLevel(resource):
251-
experiment_context = context._Context.get_or_create(
251+
experiment_context = context.Context.get_or_create(
252252
resource_id=experiment_name,
253253
display_name=experiment_name,
254254
description=description,
@@ -303,7 +303,7 @@ def list(
303303
)
304304

305305
with _SetLoggerLevel(resource):
306-
experiment_contexts = context._Context.list(
306+
experiment_contexts = context.Context.list(
307307
filter=filter_str,
308308
project=project,
309309
location=location,
@@ -341,7 +341,7 @@ def delete(self, *, delete_backing_tensorboard_runs: bool = False):
341341
runs under this experiment that we used to store time series metrics.
342342
"""
343343

344-
experiment_runs = _SUPPORTED_LOGGABLE_RESOURCES[context._Context][
344+
experiment_runs = _SUPPORTED_LOGGABLE_RESOURCES[context.Context][
345345
constants.SYSTEM_EXPERIMENT_RUN
346346
].list(experiment=self)
347347
for experiment_run in experiment_runs:
@@ -380,11 +380,11 @@ def get_data_frame(self) -> "pd.DataFrame": # noqa: F821
380380

381381
filter_str = metadata_utils._make_filter_string(
382382
schema_title=sorted(
383-
list(_SUPPORTED_LOGGABLE_RESOURCES[context._Context].keys())
383+
list(_SUPPORTED_LOGGABLE_RESOURCES[context.Context].keys())
384384
),
385385
parent_contexts=[self._metadata_context.resource_name],
386386
)
387-
contexts = context._Context.list(filter_str, **service_request_args)
387+
contexts = context.Context.list(filter_str, **service_request_args)
388388

389389
filter_str = metadata_utils._make_filter_string(
390390
schema_title=list(
@@ -398,7 +398,7 @@ def get_data_frame(self) -> "pd.DataFrame": # noqa: F821
398398
rows = []
399399
for metadata_context in contexts:
400400
row_dict = (
401-
_SUPPORTED_LOGGABLE_RESOURCES[context._Context][
401+
_SUPPORTED_LOGGABLE_RESOURCES[context.Context][
402402
metadata_context.schema_title
403403
]
404404
._query_experiment_row(metadata_context)
@@ -568,7 +568,7 @@ class _VertexResourceWithMetadata(NamedTuple):
568568
"""Represents a resource coupled with it's metadata representation"""
569569

570570
resource: base.VertexAiResourceNoun
571-
metadata: Union[artifact.Artifact, execution.Execution, context._Context]
571+
metadata: Union[artifact.Artifact, execution.Execution, context.Context]
572572

573573

574574
class _ExperimentLoggableSchema(NamedTuple):
@@ -581,7 +581,7 @@ class _ExperimentLoggableSchema(NamedTuple):
581581
"""
582582

583583
title: str
584-
type: Union[Type[context._Context], Type[execution.Execution]] = context._Context
584+
type: Union[Type[context.Context], Type[execution.Execution]] = context.Context
585585

586586

587587
class _ExperimentLoggable(abc.ABC):
@@ -618,7 +618,7 @@ class PipelineJob(..., experiment_loggable_schemas=
618618
_SUPPORTED_LOGGABLE_RESOURCES[schema.type][schema.title] = cls
619619

620620
@abc.abstractmethod
621-
def _get_context(self) -> context._Context:
621+
def _get_context(self) -> context.Context:
622622
"""Should return the metadata context that represents this resource.
623623
624624
The subclass should enforce this context exists.
@@ -631,7 +631,7 @@ def _get_context(self) -> context._Context:
631631
@classmethod
632632
@abc.abstractmethod
633633
def _query_experiment_row(
634-
cls, node: Union[context._Context, execution.Execution]
634+
cls, node: Union[context.Context, execution.Execution]
635635
) -> _ExperimentRow:
636636
"""Should return parameters and metrics for this resource as a run row.
637637
@@ -716,6 +716,6 @@ def _associate_to_experiment(self, experiment: Union[str, Experiment]):
716716
# Context -> 'system.ExperimentRun' -> aiplatform.ExperimentRun
717717
# Execution -> 'system.Run' -> aiplatform.ExperimentRun
718718
_SUPPORTED_LOGGABLE_RESOURCES: Dict[
719-
Union[Type[context._Context], Type[execution.Execution]],
719+
Union[Type[context.Context], Type[execution.Execution]],
720720
Dict[str, _ExperimentLoggable],
721-
] = {execution.Execution: dict(), context._Context: dict()}
721+
] = {execution.Execution: dict(), context.Context: dict()}

0 commit comments

Comments
 (0)