
Commit 181dc7a

matthew29tang authored and copybara-github committed
feat: Add display model evaluation button for Ipython environments
PiperOrigin-RevId: 617902950
1 parent 13ec7e0 commit 181dc7a

4 files changed: +60 −9 lines changed

google/cloud/aiplatform/model_evaluation/model_evaluation_job.py

+12 −2

@@ -18,6 +18,7 @@
 from typing import Optional, List, Union
 
 from google.auth import credentials as auth_credentials
+import grpc
 
 from google.cloud import aiplatform
 from google.cloud.aiplatform import base
@@ -27,6 +28,7 @@
 )
 from google.cloud.aiplatform import model_evaluation
 from google.cloud.aiplatform import pipeline_jobs
+from google.cloud.aiplatform.utils import _ipython_utils
 
 from google.cloud.aiplatform.compat.types import (
     pipeline_state_v1 as gca_pipeline_state_v1,
@@ -380,7 +382,6 @@ def get_model_evaluation(
             return
 
         for component in self.backing_pipeline_job.task_details:
-
            # This assumes that task_details has a task with a task_name == backing_pipeline_job.name
            if not component.task_name == self.backing_pipeline_job.name:
                continue
@@ -407,5 +408,14 @@ def get_model_evaluation(
                    evaluation_name=eval_resource_name,
                    credentials=self.credentials,
                )
-
+                _ipython_utils.display_model_evaluation_button(eval_resource)
                 return eval_resource
+
+    def wait(self) -> None:
+        """Wait for the PipelineJob to complete, then get the model evaluation resource."""
+        super().wait()
+
+        try:
+            self.get_model_evaluation()
+        except grpc.RpcError as e:
+            _LOGGER.error("Get model evaluation call failed with error %s", e)

google/cloud/aiplatform/models.py

+6 −2

@@ -49,6 +49,7 @@
 from google.cloud.aiplatform import utils
 from google.cloud.aiplatform.utils import gcs_utils
 from google.cloud.aiplatform.utils import _explanation_utils
+from google.cloud.aiplatform.utils import _ipython_utils
 from google.cloud.aiplatform import model_evaluation
 from google.cloud.aiplatform.compat.services import endpoint_service_client
 
@@ -5136,7 +5137,8 @@ def get_model_evaluation(
            _LOGGER.warning(
                f"Your model has more than one model evaluation, this is returning only one evaluation resource: {evaluations[0].resource_name}"
            )
-            return evaluations[0] if evaluations else evaluations
+            _ipython_utils.display_model_evaluation_button(evaluations[0])
+            return evaluations[0]
         else:
            resource_uri_parts = self._parse_resource_name(self.resource_name)
            evaluation_resource_name = (
@@ -5146,10 +5148,12 @@ def get_model_evaluation(
                )
            )
 
-            return model_evaluation.ModelEvaluation(
+            evaluation = model_evaluation.ModelEvaluation(
                evaluation_name=evaluation_resource_name,
                credentials=self.credentials,
            )
+            _ipython_utils.display_model_evaluation_button(evaluation)
+            return evaluation
 
     def evaluate(
         self,
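Both branches of Model.get_model_evaluation() now display the button before returning. A minimal sketch, assuming the get_model_evaluation(evaluation_id=None) signature; the project, model, and evaluation IDs are placeholders.

# Minimal sketch; all resource IDs below are placeholders.
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")

model = aiplatform.Model("projects/my-project/locations/us-central1/models/123")

# Without an evaluation_id, the first evaluation found is returned (first
# branch above); in an IPython session this now also renders the button.
evaluation = model.get_model_evaluation()

# With an explicit evaluation_id, the resource name is built directly (second
# branch above) and the same button is displayed before returning.
evaluation = model.get_model_evaluation(evaluation_id="456")
print(evaluation.resource_name)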

google/cloud/aiplatform/utils/_ipython_utils.py

+37 −2

@@ -16,11 +16,15 @@
 #
 
 import sys
+import typing
 from uuid import uuid4
 from typing import Optional
 
 from google.cloud.aiplatform import base
-from google.cloud.aiplatform.metadata import experiment_resources
+
+if typing.TYPE_CHECKING:
+    from google.cloud.aiplatform.metadata import experiment_resources
+    from google.cloud.aiplatform import model_evaluation
 
 _LOGGER = base.Logger(__name__)
 
@@ -142,7 +146,7 @@ def display_link(text: str, url: str, icon: Optional[str] = "open_in_new") -> None:
     display(HTML(html))
 
 
-def display_experiment_button(experiment: experiment_resources.Experiment) -> None:
+def display_experiment_button(experiment: "experiment_resources.Experiment") -> None:
     """Function to generate a link bound to the Vertex experiment"""
     if not is_ipython_available():
         return
@@ -162,3 +166,34 @@ def display_experiment_button(experiment: experiment_resources.Experiment) -> None:
         + f"runs?project={project}"
     )
     display_link("View Experiment", uri, "science")
+
+
+def display_model_evaluation_button(
+    evaluation: "model_evaluation.ModelEvaluation",
+) -> None:
+    """Function to generate a link bound to the Vertex model evaluation"""
+    if not is_ipython_available():
+        return
+
+    try:
+        resource_name = evaluation.resource_name
+        fields = evaluation._parse_resource_name(resource_name)
+        project = fields["project"]
+        location = fields["location"]
+        model_id = fields["model"]
+        evaluation_id = fields["evaluation"]
+    except AttributeError:
+        _LOGGER.warning("Unable to parse model evaluation metadata")
+        return
+
+    if "@" in model_id:
+        model_id, version_id = model_id.split("@")
+    else:
+        version_id = "default"
+
+    uri = (
+        "https://console.cloud.google.com/vertex-ai/models/locations/"
+        + f"{location}/models/{model_id}/versions/{version_id}/evaluations/"
+        + f"{evaluation_id}?project={project}"
+    )
+    display_link("View Model Evaluation", uri, "model_training")

tests/unit/aiplatform/test_model_evaluation.py

+5 −3

@@ -553,6 +553,9 @@ def mock_pipeline_service_get():
        make_pipeline_job(
            gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
        ),
+        make_pipeline_job(
+            gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
+        ),
     ]
 
     yield mock_get_pipeline_job
@@ -797,7 +800,6 @@ def test_get_model_evaluation_metrics(self, mock_model_eval_get):
         assert eval_metrics == _TEST_MODEL_EVAL_METRICS
 
     def test_no_delete_model_evaluation_method(self, mock_model_eval_get):
-
         my_eval = aiplatform.ModelEvaluation(
             evaluation_name=_TEST_MODEL_EVAL_RESOURCE_NAME
         )
@@ -1028,6 +1030,7 @@ def test_model_evaluation_job_submit(
        mock_load_yaml_and_json,
        mock_model,
        get_model_mock,
+        mock_model_eval_get,
        mock_model_eval_job_get,
        mock_pipeline_service_get,
        mock_model_eval_job_create,
@@ -1128,6 +1131,7 @@ def test_model_evaluation_job_submit_with_experiment(
        mock_model,
        get_model_mock,
        get_experiment_mock,
+        mock_model_eval_get,
        mock_model_eval_job_get,
        mock_pipeline_service_get,
        mock_model_eval_job_create,
@@ -1308,7 +1312,6 @@ def test_model_evaluation_job_get_model_evaluation_with_failed_pipeline_run_rais
        mock_pipeline_bucket_exists,
        mock_request_urlopen,
     ):
-
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
@@ -1388,7 +1391,6 @@ def test_model_evaluation_job_get_model_evaluation_with_pending_pipeline_run_ret
     def test_get_template_url(
         self,
     ):
-
        template_url = model_evaluation_job._ModelEvaluationJob._get_template_url(
            model_type="automl_tabular",
            feature_attributions=False,
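The test updates follow from the new wait() behavior: the extra get_model_evaluation() call consumes one more mocked pipeline GET, so the side_effect list needs an additional SUCCEEDED job, and the submit tests need the mock_model_eval_get fixture. A generic illustration of the side_effect pattern follows; the names below are illustrative, not the repository's fixtures.

# Generic illustration of mock side_effect lists; names are illustrative.
from unittest import mock

def make_response(state):
    response = mock.Mock()
    response.state = state
    return response

service = mock.Mock()
# Each call consumes the next entry, so a code path that now performs one
# extra GET needs one extra entry here.
service.get_pipeline_job.side_effect = [
    make_response("RUNNING"),
    make_response("SUCCEEDED"),
    make_response("SUCCEEDED"),  # extra entry for the additional GET in wait()
]

assert service.get_pipeline_job().state == "RUNNING"
assert service.get_pipeline_job().state == "SUCCEEDED"
assert service.get_pipeline_job().state == "SUCCEEDED"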
