Skip to content

Commit d40e612

Browse files
committed
switch to use ModelContainerSpec and PredictSchemata instead of custom_dataclasses
1 parent 49e34cc commit d40e612

File tree

4 files changed

+39
-180
lines changed

4 files changed

+39
-180
lines changed

google/cloud/aiplatform/metadata/types/google_types.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Optional, Dict, NamedTuple, List
1818
from dataclasses import dataclass
1919
from google.cloud.aiplatform.metadata.types import base
20-
from google.cloud.aiplatform.metadata.types import utils
20+
from google.cloud.aiplatform.compat.types import model_v1 as model
2121

2222

2323
class VertexDataset(base.BaseArtifactSchema):
@@ -177,8 +177,8 @@ class UnmanagedContainerModel(base.BaseArtifactSchema):
177177

178178
def __init__(
179179
self,
180-
predict_schema_ta: utils.PredictSchemata,
181-
container_spec: utils.PredictSchemata,
180+
predict_schema_ta: model.PredictSchemata,
181+
container_spec: model.ModelContainerSpec,
182182
unmanaged_container_model_name: Optional[str] = None,
183183
uri: Optional[str] = None,
184184
display_name: Optional[str] = None,
@@ -213,8 +213,35 @@ def __init__(
213213
"""
214214
extended_metadata = metadata or {}
215215
extended_metadata["resourceName"] = unmanaged_container_model_name
216-
extended_metadata["predictSchemata"] = predict_schema_ta.to_dict()
217-
extended_metadata["containerSpec"] = container_spec.to_dict()
216+
extended_metadata["predictSchemata"] = {}
217+
extended_metadata["predictSchemata"][
218+
"instanceSchemaUri"
219+
] = predict_schema_ta.instance_schema_uri
220+
extended_metadata["predictSchemata"][
221+
"parametersSchemaUri"
222+
] = predict_schema_ta.parameters_schema_uri
223+
extended_metadata["predictSchemata"][
224+
"predictionSchemaUri"
225+
] = predict_schema_ta.prediction_schema_uri
226+
227+
extended_metadata["containerSpec"] = {}
228+
extended_metadata["containerSpec"]["imageUri"] = container_spec.image_uri
229+
if container_spec.command:
230+
extended_metadata["containerSpec"]["command"] = container_spec.command
231+
if container_spec.args:
232+
extended_metadata["containerSpec"]["args"] = container_spec.args
233+
if container_spec.env:
234+
extended_metadata["containerSpec"]["env"] = container_spec.env
235+
if container_spec.ports:
236+
extended_metadata["containerSpec"]["ports"] = container_spec.ports
237+
if container_spec.predict_route:
238+
extended_metadata["containerSpec"][
239+
"predictRoute"
240+
] = container_spec.predict_route
241+
if container_spec.health_route:
242+
extended_metadata["containerSpec"][
243+
"healthRoute"
244+
] = container_spec.health_route
218245

219246
super(UnmanagedContainerModel, self).__init__(
220247
schema_title=self.SCHEMA_TITLE,

google/cloud/aiplatform/metadata/types/system_types.py

-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#
1717
from typing import Optional, Dict, List
1818
from google.cloud.aiplatform.metadata.types import base
19-
from google.cloud.aiplatform.metadata.types import utils
2019
from itertools import zip_longest
2120

2221

google/cloud/aiplatform/metadata/types/utils.py

-97
This file was deleted.

tests/unit/aiplatform/test_metadata_schema_types.py

+7-77
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from google.cloud.aiplatform.metadata.types import base
2828
from google.cloud.aiplatform.metadata.types import google_types
2929
from google.cloud.aiplatform.metadata.types import system_types
30-
from google.cloud.aiplatform.metadata.types import utils
3130

31+
from google.cloud.aiplatform.compat.types import model_v1 as model
3232
from google.cloud.aiplatform_v1 import MetadataServiceClient
3333
from google.cloud.aiplatform_v1 import Artifact as GapicArtifact
3434

@@ -212,13 +212,13 @@ def test_vertex_endpoint_constructor_parameters_are_set_correctly(self):
212212
assert artifact.schema_version == _TEST_SCHEMA_VERSION
213213

214214
def test_unmanaged_container_model_title_is_set_correctly(self):
215-
predict_schema_ta = utils.PredictSchemata(
215+
predict_schema_ta = model.PredictSchemata(
216216
instance_schema_uri="instance_uri",
217217
prediction_schema_uri="prediction_uri",
218218
parameters_schema_uri="parameters_uri",
219219
)
220220

221-
container_spec = utils.ContainerSpec(
221+
container_spec = model.ModelContainerSpec(
222222
image_uri="gcr.io/test_container_image_uri"
223223
)
224224
artifact = google_types.UnmanagedContainerModel(
@@ -228,13 +228,13 @@ def test_unmanaged_container_model_title_is_set_correctly(self):
228228
assert artifact.schema_title == "google.UnmanagedContainerModel"
229229

230230
def test_unmanaged_container_model_resouce_name_is_set_in_metadata(self):
231-
predict_schema_ta = utils.PredictSchemata(
231+
predict_schema_ta = model.PredictSchemata(
232232
instance_schema_uri="instance_uri",
233233
prediction_schema_uri="prediction_uri",
234234
parameters_schema_uri="parameters_uri",
235235
)
236236

237-
container_spec = utils.ContainerSpec(
237+
container_spec = model.ModelContainerSpec(
238238
image_uri="gcr.io/test_container_image_uri"
239239
)
240240
artifact = google_types.UnmanagedContainerModel(
@@ -245,13 +245,13 @@ def test_unmanaged_container_model_resouce_name_is_set_in_metadata(self):
245245
assert artifact.metadata["resourceName"] == _TEST_ARTIFACT_NAME
246246

247247
def test_unmanaged_container_model_constructor_parameters_are_set_correctly(self):
248-
predict_schema_ta = utils.PredictSchemata(
248+
predict_schema_ta = model.PredictSchemata(
249249
instance_schema_uri="instance_uri",
250250
prediction_schema_uri="prediction_uri",
251251
parameters_schema_uri="parameters_uri",
252252
)
253253

254-
container_spec = utils.ContainerSpec(
254+
container_spec = model.ModelContainerSpec(
255255
image_uri="gcr.io/test_container_image_uri"
256256
)
257257

@@ -360,73 +360,3 @@ def test_system_metrics_constructor_parameters_are_set_correctly(self):
360360
assert artifact.metadata["f1score"] == 0.4
361361
assert artifact.metadata["mean_absolute_error"] == 0.5
362362
assert artifact.metadata["mean_squared_error"] == 0.6
363-
364-
365-
class TestMetadataUtils:
366-
def setup_method(self):
367-
reload(initializer)
368-
reload(metadata)
369-
reload(aiplatform)
370-
371-
def teardown_method(self):
372-
initializer.global_pool.shutdown(wait=True)
373-
374-
def test_predict_schemata_to_dict_method_returns_correct_schema(self):
375-
predict_schema_ta = utils.PredictSchemata(
376-
instance_schema_uri="instance_uri",
377-
prediction_schema_uri="prediction_uri",
378-
parameters_schema_uri="parameters_uri",
379-
)
380-
expected_results = {
381-
"instanceSchemaUri": "instance_uri",
382-
"parametersSchemaUri": "parameters_uri",
383-
"predictionSchemaUri": "prediction_uri",
384-
}
385-
386-
assert json.dumps(predict_schema_ta.to_dict()) == json.dumps(expected_results)
387-
388-
def test_container_spec_to_dict_method_returns_correct_schema(self):
389-
container_spec = utils.ContainerSpec(
390-
image_uri="gcr.io/some_container_image_uri",
391-
command=["test_command"],
392-
args=["test_args"],
393-
env=[{"env_var_name": "env_var_value"}],
394-
ports=[1],
395-
predict_route="test_prediction_rout",
396-
health_route="test_health_rout",
397-
)
398-
399-
expected_results = {
400-
"imageUri": "gcr.io/some_container_image_uri",
401-
"command": ["test_command"],
402-
"args": ["test_args"],
403-
"env": [{"env_var_name": "env_var_value"}],
404-
"ports": [1],
405-
"predictRoute": "test_prediction_rout",
406-
"healthRoute": "test_health_rout",
407-
}
408-
409-
assert json.dumps(container_spec.to_dict()) == json.dumps(expected_results)
410-
411-
def test_container_spec_to_dict_method_returns_correct_schema(self):
412-
container_spec = utils.ContainerSpec(
413-
image_uri="gcr.io/some_container_image_uri",
414-
command=["test_command"],
415-
args=["test_args"],
416-
env=[{"env_var_name": "env_var_value"}],
417-
ports=[1],
418-
predict_route="test_prediction_rout",
419-
health_route="test_health_rout",
420-
)
421-
422-
expected_results = {
423-
"imageUri": "gcr.io/some_container_image_uri",
424-
"command": ["test_command"],
425-
"args": ["test_args"],
426-
"env": [{"env_var_name": "env_var_value"}],
427-
"ports": [1],
428-
"predictRoute": "test_prediction_rout",
429-
"healthRoute": "test_health_rout",
430-
}
431-
432-
assert json.dumps(container_spec.to_dict()) == json.dumps(expected_results)

0 commit comments

Comments
 (0)