Skip to content

Commit b1e9a6c

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add display tuning job button for Ipython environments when starting a new job
PiperOrigin-RevId: 646484448
1 parent 78a92a1 commit b1e9a6c

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

google/cloud/aiplatform/utils/_ipython_utils.py

+24
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from google.cloud.aiplatform.metadata import experiment_resources
2727
from google.cloud.aiplatform.metadata import experiment_run_resource
2828
from google.cloud.aiplatform import model_evaluation
29+
from vertexai.preview.tuning import sft
2930

3031
_LOGGER = base.Logger(__name__)
3132

@@ -228,3 +229,26 @@ def display_model_evaluation_button(
228229
+ f"{evaluation_id}?project={project}"
229230
)
230231
display_link("View Model Evaluation", uri, "lightbulb")
232+
233+
234+
def display_model_tuning_button(tuning_job: "sft.SupervisedTuningJob") -> None:
235+
"""Function to generate a link bound to the Vertex model tuning job."""
236+
if not is_ipython_available():
237+
return
238+
239+
try:
240+
resource_name = tuning_job.resource_name
241+
fields = tuning_job._parse_resource_name(resource_name)
242+
project = fields["project"]
243+
location = fields["location"]
244+
tuning_job_id = fields["tuning_job"]
245+
except AttributeError:
246+
_LOGGER.warning("Unable to parse tuning job metadata")
247+
return
248+
249+
uri = (
250+
"https://console.cloud.google.com/vertex-ai/generative/language/"
251+
+ f"locations/{location}/tuning/tuningJob/{tuning_job_id}"
252+
+ f"?project={project}"
253+
)
254+
display_link("View Tuning Job", uri, "tune")

vertexai/tuning/_supervised_tuning.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from typing import Dict, Literal, Optional, Union
1717

18+
from google.cloud.aiplatform.utils import _ipython_utils
1819
from google.cloud.aiplatform_v1beta1.types import (
1920
tuning_job as gca_tuning_job_types,
2021
)
@@ -87,12 +88,17 @@ def train(
8788
if isinstance(source_model, generative_models.GenerativeModel):
8889
source_model = source_model._prediction_resource_name.rpartition("/")[-1]
8990

90-
return SupervisedTuningJob._create( # pylint: disable=protected-access
91-
base_model=source_model,
92-
tuning_spec=supervised_tuning_spec,
93-
tuned_model_display_name=tuned_model_display_name,
94-
labels=labels,
91+
supervised_tuning_job = (
92+
SupervisedTuningJob._create( # pylint: disable=protected-access
93+
base_model=source_model,
94+
tuning_spec=supervised_tuning_spec,
95+
tuned_model_display_name=tuned_model_display_name,
96+
labels=labels,
97+
)
9598
)
99+
_ipython_utils.display_model_tuning_button(supervised_tuning_job)
100+
101+
return supervised_tuning_job
96102

97103

98104
class SupervisedTuningJob(_tuning.TuningJob):

0 commit comments

Comments
 (0)