Commit f8aea02

chore: update Vertex Model artifact resolution to use versioned models and reduce allowed experiment run name length (#1468)
1 parent d07715a commit f8aea02

File tree

6 files changed: +145 additions, -6 deletions


google/cloud/aiplatform/metadata/artifact.py

Lines changed: 5 additions & 1 deletion
@@ -535,14 +535,18 @@ def create_vertex_resource_artifact(cls, resource: Union[models.Model]) -> Artif
         """
         cls.validate_resource_supports_metadata(resource)
         resource.wait()
+
         metadata_type = cls._resource_to_artifact_type[type(resource)]
         uri = rest_utils.make_gcp_resource_rest_url(resource=resource)
 
         return Artifact.create(
             schema_title=metadata_type,
             display_name=getattr(resource.gca_resource, "display_name", None),
             uri=uri,
-            metadata={"resourceName": resource.resource_name},
+            # Note that support for non-versioned resources requires
+            # change to reference `resource_name` please update if
+            # supporting resource other than Model
+            metadata={"resourceName": resource.versioned_resource_name},
             project=resource.project,
             location=resource.location,
             credentials=resource.credentials,
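
In practice, the change means a Vertex Model artifact's metadata now records the versioned model name. A minimal, self-contained sketch of the before/after; the resource names and the stand-in class are made up for illustration and are not SDK types or values from this commit:

# Stand-in for a registered Model; only the two attributes that matter here.
class FakeModel:
    resource_name = "projects/my-project/locations/us-central1/models/123"
    versioned_resource_name = "projects/my-project/locations/us-central1/models/123@1"


model = FakeModel()

# Before this commit the artifact metadata pointed at the unversioned model:
old_metadata = {"resourceName": model.resource_name}

# After this commit it points at the specific model version, so lineage records
# which version an execution actually produced or consumed:
new_metadata = {"resourceName": model.versioned_resource_name}

print(old_metadata["resourceName"])  # ...models/123
print(new_metadata["resourceName"])  # ...models/123@1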

google/cloud/aiplatform/metadata/constants.py

Lines changed: 3 additions & 0 deletions
@@ -65,3 +65,6 @@
     schema_version="0.0.1",
     metadata={_VERTEX_EXPERIMENT_TRACKING_LABEL: True},
 )
+
+_TB_RUN_ARTIFACT_POST_FIX_ID = "-tb-run"
+_EXPERIMENT_RUN_MAX_LENGTH = 128 - len(_TB_RUN_ARTIFACT_POST_FIX_ID)
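
The arithmetic behind the new limit, spelled out; the 128-character ceiling is the value the run-ID check used before this commit (see the next file):

_TB_RUN_ARTIFACT_POST_FIX_ID = "-tb-run"
_EXPERIMENT_RUN_MAX_LENGTH = 128 - len(_TB_RUN_ARTIFACT_POST_FIX_ID)  # 121

# A run ID at the new limit still fits once the TensorBoard run suffix is
# appended, keeping the derived artifact ID within the old 128-character bound.
run_id = "x" * _EXPERIMENT_RUN_MAX_LENGTH
tensorboard_run_artifact_id = run_id + _TB_RUN_ARTIFACT_POST_FIX_ID
assert len(tensorboard_run_artifact_id) == 128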

google/cloud/aiplatform/metadata/experiment_run_resource.py

Lines changed: 3 additions & 3 deletions
@@ -607,9 +607,9 @@ def _validate_run_id(run_id: str):
             ValueError if run id is too long.
         """
 
-        if len(run_id) > 128:
+        if len(run_id) > constants._EXPERIMENT_RUN_MAX_LENGTH:
             raise ValueError(
-                f"Length of Experiment ID and Run ID cannot be greater than 128. "
+                f"Length of Experiment ID and Run ID cannot be greater than {constants._EXPERIMENT_RUN_MAX_LENGTH}. "
                 f"{run_id} is of length {len(run_id)}"
             )
 
@@ -822,7 +822,7 @@ def _tensorboard_run_id(run_id: str) -> str:
         Returns:
             Resource id for the associated tensorboard run artifact.
         """
-        return f"{run_id}-tb-run"
+        return f"{run_id}{constants._TB_RUN_ARTIFACT_POST_FIX_ID}"
 
     @_v1_not_supported
     def assign_backing_tensorboard(
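
Taken together, the combined experiment-and-run ID is now capped at 121 characters instead of 128. A small illustration of the check, assuming the validated ID is the experiment name and run name joined with a hyphen (as the error message implies); this mirrors the private helper rather than importing it:

MAX_LENGTH = 128 - len("-tb-run")  # mirrors constants._EXPERIMENT_RUN_MAX_LENGTH

# Made-up names, chosen so the combined ID exceeds the new limit.
experiment_name = "my-experiment"
run_name = "a" * 120
run_id = f"{experiment_name}-{run_name}"

if len(run_id) > MAX_LENGTH:
    print(f"rejected: combined ID is {len(run_id)} chars, limit is {MAX_LENGTH}")
else:
    print("accepted")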

google/cloud/aiplatform/utils/rest_utils.py

Lines changed: 4 additions & 1 deletion
@@ -25,7 +25,10 @@ def make_gcp_resource_rest_url(resource: base.VertexAiResourceNoun) -> str:
     Returns:
         The formatted url of resource.
     """
-    resource_name = resource.resource_name
+    try:
+        resource_name = resource.versioned_resource_name
+    except AttributeError:
+        resource_name = resource.resource_name
     version = resource.api_client._default_version
     api_uri = resource.api_client.api_endpoint
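
The fallback can be exercised without any GCP calls. A self-contained sketch of the pattern: resources that expose a versioned name (Model, after this commit) resolve to it, everything else keeps using the plain resource name. The classes below are stand-ins for illustration, not SDK types:

# Stand-ins: a resource with a versioned name and one without.
class VersionedStandIn:
    resource_name = "projects/p/locations/us-central1/models/123"
    versioned_resource_name = "projects/p/locations/us-central1/models/123@2"


class UnversionedStandIn:
    resource_name = "projects/p/locations/us-central1/endpoints/456"


def resolve_name(resource) -> str:
    """Prefer the versioned resource name; fall back when it does not exist."""
    try:
        return resource.versioned_resource_name
    except AttributeError:
        return resource.resource_name


assert resolve_name(VersionedStandIn()).endswith("models/123@2")
assert resolve_name(UnversionedStandIn()).endswith("endpoints/456")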

tests/unit/aiplatform/test_metadata.py

Lines changed: 19 additions & 1 deletion
@@ -993,7 +993,6 @@ def test_init_experiment_wrong_schema(self):
         )
 
     @pytest.mark.usefixtures("get_metadata_store_mock")
-    @pytest.mark.usefixtures()
     def test_start_run(
         self,
         get_experiment_mock,
@@ -1025,6 +1024,25 @@ def test_start_run(
             context=_EXPERIMENT_MOCK.name, child_contexts=[_EXPERIMENT_RUN_MOCK.name]
         )
 
+    @pytest.mark.usefixtures("get_metadata_store_mock", "get_experiment_mock")
+    def test_start_run_fails_when_run_name_too_long(self):
+
+        aiplatform.init(
+            project=_TEST_PROJECT,
+            location=_TEST_LOCATION,
+            experiment=_TEST_EXPERIMENT,
+        )
+
+        run_name_too_long = "".join(
+            "a"
+            for _ in range(
+                constants._EXPERIMENT_RUN_MAX_LENGTH + 2 - len(_TEST_EXPERIMENT)
+            )
+        )
+
+        with pytest.raises(ValueError):
+            aiplatform.start_run(run_name_too_long)
+
     @pytest.mark.usefixtures(
         "get_metadata_store_mock",
         "get_experiment_mock",

tests/unit/aiplatform/test_metadata_resources.py

Lines changed: 111 additions & 0 deletions
@@ -28,6 +28,7 @@
 from google.cloud.aiplatform.metadata import artifact
 from google.cloud.aiplatform.metadata import context
 from google.cloud.aiplatform.metadata import execution
+from google.cloud.aiplatform.metadata import utils as metadata_utils
 from google.cloud.aiplatform_v1 import (
     MetadataServiceClient,
     AddExecutionEventsResponse,
@@ -39,6 +40,8 @@
     AddContextArtifactsAndExecutionsResponse,
 )
 
+import test_models
+
 # project
 _TEST_PROJECT = "test-project"
 _TEST_LOCATION = "us-central1"
@@ -543,6 +546,34 @@ def test_add_executions_only(self, add_context_artifacts_and_executions_mock):
         )
 
 
+get_model_with_version_mock = test_models.get_model_with_version
+_VERTEX_MODEL_ARTIFACT_URI = f"https://{_TEST_LOCATION}-aiplatform.googleapis.com/v1/{test_models._TEST_MODEL_OBJ_WITH_VERSION.name}"
+
+
+@pytest.fixture
+def list_vertex_model_artifact_mock():
+    with patch.object(MetadataServiceClient, "list_artifacts") as list_artifacts_mock:
+        list_artifacts_mock.return_value = [
+            GapicArtifact(
+                name=_TEST_ARTIFACT_NAME,
+                uri=_VERTEX_MODEL_ARTIFACT_URI,
+                display_name=_TEST_DISPLAY_NAME,
+                schema_title=_TEST_SCHEMA_TITLE,
+                schema_version=_TEST_SCHEMA_VERSION,
+                description=_TEST_DESCRIPTION,
+                metadata=_TEST_METADATA,
+            )
+        ]
+        yield list_artifacts_mock
+
+
+@pytest.fixture
+def list_artifact_empty_mock():
+    with patch.object(MetadataServiceClient, "list_artifacts") as list_artifacts_mock:
+        list_artifacts_mock.return_value = []
+        yield list_artifacts_mock
+
+
 class TestExecution:
     def setup_method(self):
         reload(initializer)
@@ -680,6 +711,86 @@ def test_add_artifact(self, add_execution_events_mock):
             events=[Event(artifact=_TEST_ARTIFACT_NAME, type_=Event.Type.OUTPUT)],
         )
 
+    @pytest.mark.usefixtures("get_execution_mock", "get_model_with_version_mock")
+    def test_add_vertex_model(
+        self, add_execution_events_mock, list_vertex_model_artifact_mock
+    ):
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        my_execution = execution.Execution.get_or_create(
+            resource_id=_TEST_EXECUTION_ID,
+            schema_title=_TEST_SCHEMA_TITLE,
+            display_name=_TEST_DISPLAY_NAME,
+            schema_version=_TEST_SCHEMA_VERSION,
+            description=_TEST_DESCRIPTION,
+            metadata=_TEST_METADATA,
+            metadata_store_id=_TEST_METADATA_STORE,
+        )
+
+        my_model = aiplatform.Model(test_models._TEST_MODEL_NAME)
+        my_execution.assign_output_artifacts(artifacts=[my_model])
+
+        list_vertex_model_artifact_mock.assert_called_once_with(
+            request=dict(
+                parent="projects/test-project/locations/us-central1/metadataStores/default",
+                filter=metadata_utils._make_filter_string(
+                    schema_title="google.VertexModel", uri=_VERTEX_MODEL_ARTIFACT_URI
+                ),
+            )
+        )
+
+        add_execution_events_mock.assert_called_once_with(
+            execution=_TEST_EXECUTION_NAME,
+            events=[Event(artifact=_TEST_ARTIFACT_NAME, type_=Event.Type.OUTPUT)],
+        )
+
+    @pytest.mark.usefixtures("get_execution_mock", "get_model_with_version_mock")
+    def test_add_vertex_model_not_resolved(
+        self, add_execution_events_mock, list_artifact_empty_mock, create_artifact_mock
+    ):
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        my_execution = execution.Execution.get_or_create(
+            resource_id=_TEST_EXECUTION_ID,
+            schema_title=_TEST_SCHEMA_TITLE,
+            display_name=_TEST_DISPLAY_NAME,
+            schema_version=_TEST_SCHEMA_VERSION,
+            description=_TEST_DESCRIPTION,
+            metadata=_TEST_METADATA,
+            metadata_store_id=_TEST_METADATA_STORE,
+        )
+
+        my_model = aiplatform.Model(test_models._TEST_MODEL_NAME)
+        my_execution.assign_output_artifacts(artifacts=[my_model])
+
+        list_artifact_empty_mock.assert_called_once_with(
+            request=dict(
+                parent="projects/test-project/locations/us-central1/metadataStores/default",
+                filter=metadata_utils._make_filter_string(
+                    schema_title="google.VertexModel", uri=_VERTEX_MODEL_ARTIFACT_URI
+                ),
+            )
+        )
+
+        expected_artifact = GapicArtifact(
+            schema_title="google.VertexModel",
+            display_name=test_models._TEST_MODEL_OBJ_WITH_VERSION.display_name,
+            uri=_VERTEX_MODEL_ARTIFACT_URI,
+            metadata={"resourceName": test_models._TEST_MODEL_OBJ_WITH_VERSION.name},
+            state=GapicArtifact.State.LIVE,
+        )
+
+        create_artifact_mock.assert_called_once_with(
+            parent="projects/test-project/locations/us-central1/metadataStores/default",
+            artifact=expected_artifact,
+            artifact_id=None,
+        )
+
+        add_execution_events_mock.assert_called_once_with(
+            execution=_TEST_EXECUTION_NAME,
+            events=[Event(artifact=_TEST_ARTIFACT_NAME, type_=Event.Type.OUTPUT)],
+        )
+
     @pytest.mark.usefixtures("get_execution_mock")
     def test_query_input_and_output_artifacts(
         self, query_execution_inputs_and_outputs_mock
