File tree 2 files changed +35
-5
lines changed
google/cloud/aiplatform/utils
2 files changed +35
-5
lines changed Original file line number Diff line number Diff line change 26
26
from google .cloud .aiplatform .metadata import experiment_resources
27
27
from google .cloud .aiplatform .metadata import experiment_run_resource
28
28
from google .cloud .aiplatform import model_evaluation
29
+ from vertexai .preview .tuning import sft
29
30
30
31
_LOGGER = base .Logger (__name__ )
31
32
@@ -228,3 +229,26 @@ def display_model_evaluation_button(
228
229
+ f"{ evaluation_id } ?project={ project } "
229
230
)
230
231
display_link ("View Model Evaluation" , uri , "lightbulb" )
232
+
233
+
234
+ def display_model_tuning_button (tuning_job : "sft.SupervisedTuningJob" ) -> None :
235
+ """Function to generate a link bound to the Vertex model tuning job."""
236
+ if not is_ipython_available ():
237
+ return
238
+
239
+ try :
240
+ resource_name = tuning_job .resource_name
241
+ fields = tuning_job ._parse_resource_name (resource_name )
242
+ project = fields ["project" ]
243
+ location = fields ["location" ]
244
+ tuning_job_id = fields ["tuning_job" ]
245
+ except AttributeError :
246
+ _LOGGER .warning ("Unable to parse tuning job metadata" )
247
+ return
248
+
249
+ uri = (
250
+ "https://console.cloud.google.com/vertex-ai/generative/language/"
251
+ + f"locations/{ location } /tuning/tuningJob/{ tuning_job_id } "
252
+ + f"?project={ project } "
253
+ )
254
+ display_link ("View Tuning Job" , uri , "tune" )
Original file line number Diff line number Diff line change 15
15
16
16
from typing import Dict , Literal , Optional , Union
17
17
18
+ from google .cloud .aiplatform .utils import _ipython_utils
18
19
from google .cloud .aiplatform_v1beta1 .types import (
19
20
tuning_job as gca_tuning_job_types ,
20
21
)
@@ -87,12 +88,17 @@ def train(
87
88
if isinstance (source_model , generative_models .GenerativeModel ):
88
89
source_model = source_model ._prediction_resource_name .rpartition ("/" )[- 1 ]
89
90
90
- return SupervisedTuningJob ._create ( # pylint: disable=protected-access
91
- base_model = source_model ,
92
- tuning_spec = supervised_tuning_spec ,
93
- tuned_model_display_name = tuned_model_display_name ,
94
- labels = labels ,
91
+ supervised_tuning_job = (
92
+ SupervisedTuningJob ._create ( # pylint: disable=protected-access
93
+ base_model = source_model ,
94
+ tuning_spec = supervised_tuning_spec ,
95
+ tuned_model_display_name = tuned_model_display_name ,
96
+ labels = labels ,
97
+ )
95
98
)
99
+ _ipython_utils .display_model_tuning_button (supervised_tuning_job )
100
+
101
+ return supervised_tuning_job
96
102
97
103
98
104
class SupervisedTuningJob (_tuning .TuningJob ):
You can’t perform that action at this time.
0 commit comments