Skip to content

Commit 94b2f29

Browse files
jaycee-licopybara-github
authored andcommitted
feat: add MLMD schema class ExperimentModel
PiperOrigin-RevId: 501468901
1 parent 6fa93a4 commit 94b2f29

File tree

4 files changed

+408
-8
lines changed

4 files changed

+408
-8
lines changed

google/cloud/aiplatform/metadata/schema/base_artifact.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ def _init_with_resource_name(
107107
self,
108108
*,
109109
artifact_name: str,
110+
metadata_store_id: str = "default",
111+
project: Optional[str] = None,
112+
location: Optional[str] = None,
113+
credentials: Optional[auth_credentials.Credentials] = None,
110114
):
111115

112116
"""Initializes the Artifact instance using an existing resource.
@@ -115,13 +119,31 @@ def _init_with_resource_name(
115119
artifact_name (str):
116120
Artifact name with the following format, this is globally unique in a metadataStore:
117121
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
122+
metadata_store_id (str):
123+
Optional. MetadataStore to retrieve Artifact from. If not set, metadata_store_id is set to "default".
124+
If artifact_name is a fully-qualified resource, its metadata_store_id overrides this one.
125+
project (str):
126+
Optional. Project to retrieve the artifact from. If not set, project
127+
set in aiplatform.init will be used.
128+
location (str):
129+
Optional. Location to retrieve the Artifact from. If not set, location
130+
set in aiplatform.init will be used.
131+
credentials (auth_credentials.Credentials):
132+
Optional. Custom credentials to use to retrieve this Artifact. Overrides
133+
credentials set in aiplatform.init.
118134
"""
119135
# Add User Agent Header for metrics tracking if one is not specified
120136
# If one is already specified this call was initiated by a sub class.
121137
if not base_constants.USER_AGENT_SDK_COMMAND:
122138
base_constants.USER_AGENT_SDK_COMMAND = "aiplatform.metadata.schema.base_artifact.BaseArtifactSchema._init_with_resource_name"
123139

124-
super(BaseArtifactSchema, self).__init__(artifact_name=artifact_name)
140+
super(BaseArtifactSchema, self).__init__(
141+
artifact_name=artifact_name,
142+
metadata_store_id=metadata_store_id,
143+
project=project,
144+
location=location,
145+
credentials=credentials,
146+
)
125147

126148
def create(
127149
self,

google/cloud/aiplatform/metadata/schema/google/artifact_schema.py

+156-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import copy
1818
from typing import Optional, Dict, List
1919

20+
from google.auth import credentials as auth_credentials
2021
from google.cloud.aiplatform.compat.types import artifact as gca_artifact
2122
from google.cloud.aiplatform.metadata.schema import base_artifact
2223
from google.cloud.aiplatform.metadata.schema import utils
@@ -359,7 +360,6 @@ def __init__(
359360
extended_metadata = copy.deepcopy(metadata) if metadata else {}
360361
if aggregation_type:
361362
if aggregation_type not in _CLASSIFICATION_METRICS_AGGREGATION_TYPE:
362-
## Todo: add negative test case for this
363363
raise ValueError(
364364
"aggregation_type can only be 'AGGREGATION_TYPE_UNSPECIFIED', 'MACRO_AVERAGE', or 'MICRO_AVERAGE'."
365365
)
@@ -583,3 +583,158 @@ def __init__(
583583
metadata=extended_metadata,
584584
state=state,
585585
)
586+
587+
588+
class ExperimentModel(base_artifact.BaseArtifactSchema):
589+
"""An artifact representing a Vertex Experiment Model."""
590+
591+
schema_title = "google.ExperimentModel"
592+
593+
RESERVED_METADATA_KEYS = [
594+
"frameworkName",
595+
"frameworkVersion",
596+
"modelFile",
597+
"modelClass",
598+
"predictSchemata",
599+
]
600+
601+
def __init__(
602+
self,
603+
*,
604+
framework_name: str,
605+
framework_version: str,
606+
model_file: str,
607+
uri: str,
608+
model_class: Optional[str] = None,
609+
predict_schemata: Optional[utils.PredictSchemata] = None,
610+
artifact_id: Optional[str] = None,
611+
display_name: Optional[str] = None,
612+
schema_version: Optional[str] = None,
613+
description: Optional[str] = None,
614+
metadata: Optional[Dict] = None,
615+
state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE,
616+
):
617+
"""Args:
618+
framework_name (str):
619+
Required. The name of the model's framework. E.g., 'sklearn'
620+
framework_version (str):
621+
Required. The version of the model's framework. E.g., '1.1.0'
622+
model_file (str):
623+
Required. The file name of the model. E.g., 'model.pkl'
624+
uri (str):
625+
Required. The uniform resource identifier of the model artifact directory.
626+
model_class (str):
627+
Optional. The class name of the model. E.g., 'sklearn.linear_model._base.LinearRegression'
628+
predict_schemata (PredictSchemata):
629+
Optional. An instance of PredictSchemata which holds instance, parameter and prediction schema uris.
630+
artifact_id (str):
631+
Optional. The <resource_id> portion of the Artifact name with
632+
the format. This is globally unique in a metadataStore:
633+
projects/123/locations/us-central1/metadataStores/<metadata_store_id>/artifacts/<resource_id>.
634+
display_name (str):
635+
Optional. The user-defined name of the Artifact.
636+
schema_version (str):
637+
Optional. schema_version specifies the version used by the Artifact.
638+
If not set, defaults to use the latest version.
639+
description (str):
640+
Optional. Describes the purpose of the Artifact to be created.
641+
metadata (Dict):
642+
Optional. Contains the metadata information that will be stored in the Artifact.
643+
state (google.cloud.gapic.types.Artifact.State):
644+
Optional. The state of this Artifact. This is a
645+
property of the Artifact, and does not imply or
646+
apture any ongoing process. This property is
647+
managed by clients (such as Vertex AI
648+
Pipelines), and the system does not prescribe or
649+
check the validity of state transitions.
650+
"""
651+
if metadata:
652+
for k in metadata:
653+
if k in self.RESERVED_METADATA_KEYS:
654+
raise ValueError(f"'{k}' is a system reserved key in metadata.")
655+
extended_metadata = copy.deepcopy(metadata)
656+
else:
657+
extended_metadata = {}
658+
extended_metadata["frameworkName"] = framework_name
659+
extended_metadata["frameworkVersion"] = framework_version
660+
extended_metadata["modelFile"] = model_file
661+
if model_class is not None:
662+
extended_metadata["modelClass"] = model_class
663+
if predict_schemata is not None:
664+
extended_metadata["predictSchemata"] = predict_schemata.to_dict()
665+
666+
super().__init__(
667+
uri=uri,
668+
artifact_id=artifact_id,
669+
display_name=display_name,
670+
schema_version=schema_version,
671+
description=description,
672+
metadata=extended_metadata,
673+
state=state,
674+
)
675+
676+
@classmethod
677+
def get(
678+
cls,
679+
artifact_id: str,
680+
*,
681+
metadata_store_id: str = "default",
682+
project: Optional[str] = None,
683+
location: Optional[str] = None,
684+
credentials: Optional[auth_credentials.Credentials] = None,
685+
) -> "ExperimentModel":
686+
"""Retrieves an existing ExperimentModel artifact given an artifact id.
687+
688+
Args:
689+
artifact_id (str):
690+
Required. An artifact id of the ExperimentModel artifact.
691+
metadata_store_id (str):
692+
Optional. MetadataStore to retrieve Artifact from. If not set, metadata_store_id is set to "default".
693+
If artifact_id is a fully-qualified resource name, its metadata_store_id overrides this one.
694+
project (str):
695+
Optional. Project to retrieve the artifact from. If not set, project
696+
set in aiplatform.init will be used.
697+
location (str):
698+
Optional. Location to retrieve the Artifact from. If not set, location
699+
set in aiplatform.init will be used.
700+
credentials (auth_credentials.Credentials):
701+
Optional. Custom credentials to use to retrieve this Artifact. Overrides
702+
credentials set in aiplatform.init.
703+
704+
Returns:
705+
An ExperimentModel class that represents an Artifact resource.
706+
707+
Raises:
708+
ValueError: if artifact's schema title is not 'google.ExperimentModel'.
709+
"""
710+
experiment_model = ExperimentModel(
711+
framework_name="",
712+
framework_version="",
713+
model_file="",
714+
uri="",
715+
)
716+
experiment_model._init_with_resource_name(
717+
artifact_name=artifact_id,
718+
metadata_store_id=metadata_store_id,
719+
project=project,
720+
location=location,
721+
credentials=credentials,
722+
)
723+
if experiment_model.schema_title != cls.schema_title:
724+
raise ValueError(
725+
f"The schema title of the artifact must be {cls.schema_title}."
726+
f"Got {experiment_model.schema_title}."
727+
)
728+
return experiment_model
729+
730+
@property
731+
def framework_name(self) -> Optional[str]:
732+
return self.metadata.get("frameworkName")
733+
734+
@property
735+
def framework_version(self) -> Optional[str]:
736+
return self.metadata.get("frameworkVersion")
737+
738+
@property
739+
def model_class(self) -> Optional[str]:
740+
return self.metadata.get("modelClass")

google/cloud/aiplatform/metadata/schema/utils.py

+42-6
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@ class PredictSchemata:
5050
prediction_schema_uri: str
5151

5252
def to_dict(self):
53-
"""ML metadata schema dictionary representation of this DataClass"""
53+
"""ML metadata schema dictionary representation of this DataClass.
54+
55+
56+
Returns:
57+
A dictionary that represents the PredictSchemata class.
58+
"""
5459
results = {}
5560
results["instanceSchemaUri"] = self.instance_schema_uri
5661
results["parametersSchemaUri"] = self.parameters_schema_uri
@@ -62,6 +67,7 @@ def to_dict(self):
6267
@dataclass
6368
class ContainerSpec:
6469
"""Container configuration for the model.
70+
6571
Args:
6672
image_uri (str):
6773
Required. URI of the Docker image to be used as the custom
@@ -124,7 +130,12 @@ class ContainerSpec:
124130
health_route: Optional[str] = None
125131

126132
def to_dict(self):
127-
"""ML metadata schema dictionary representation of this DataClass"""
133+
"""ML metadata schema dictionary representation of this DataClass.
134+
135+
136+
Returns:
137+
A dictionary that represents the ContainerSpec class.
138+
"""
128139
results = {}
129140
results["imageUri"] = self.image_uri
130141
if self.command:
@@ -146,6 +157,7 @@ def to_dict(self):
146157
@dataclass
147158
class AnnotationSpec:
148159
"""A class that represents the annotation spec of a Confusion Matrix.
160+
149161
Args:
150162
display_name (str):
151163
Optional. Display name for a column of a confusion matrix.
@@ -157,7 +169,12 @@ class AnnotationSpec:
157169
id: Optional[str] = None
158170

159171
def to_dict(self):
160-
"""ML metadata schema dictionary representation of this DataClass"""
172+
"""ML metadata schema dictionary representation of this DataClass.
173+
174+
175+
Returns:
176+
A dictionary that represents the AnnotationSpec class.
177+
"""
161178
results = {}
162179
if self.display_name:
163180
results["displayName"] = self.display_name
@@ -170,6 +187,7 @@ def to_dict(self):
170187
@dataclass
171188
class ConfusionMatrix:
172189
"""A class that represents a Confusion Matrix.
190+
173191
Args:
174192
matrix (List[List[int]]):
175193
Required. A 2D array of integers that represets the values for the confusion matrix.
@@ -181,10 +199,23 @@ class ConfusionMatrix:
181199
annotation_specs: Optional[List[AnnotationSpec]] = None
182200

183201
def to_dict(self):
184-
## Todo: add a validation to check 'matrix' and 'annotation_specs' have the same length
185-
"""ML metadata schema dictionary representation of this DataClass"""
202+
"""ML metadata schema dictionary representation of this DataClass.
203+
204+
Returns:
205+
A dictionary that represents the ConfusionMatrix class.
206+
207+
Raises:
208+
ValueError: if annotation_specs and matrix have different length.
209+
"""
186210
results = {}
187211
if self.annotation_specs:
212+
if len(self.annotation_specs) != len(self.matrix):
213+
raise ValueError(
214+
"Length of annotation_specs and matrix must be the same. "
215+
"Got lengths {} and {} respectively.".format(
216+
len(self.annotation_specs), len(self.matrix)
217+
)
218+
)
188219
results["annotationSpecs"] = [
189220
annotation_spec.to_dict() for annotation_spec in self.annotation_specs
190221
]
@@ -255,7 +286,12 @@ class ConfidenceMetric:
255286
confusion_matrix: Optional[ConfusionMatrix] = None
256287

257288
def to_dict(self):
258-
"""ML metadata schema dictionary representation of this DataClass"""
289+
"""ML metadata schema dictionary representation of this DataClass.
290+
291+
292+
Returns:
293+
A dictionary that represents the ConfidenceMetric class.
294+
"""
259295
results = {}
260296
results["confidenceThreshold"] = self.confidence_threshold
261297
if self.recall is not None:

0 commit comments

Comments
 (0)