@@ -257,3 +257,85 @@ def _dashboard_url(self) -> str:
257
257
job = list (fields .values ())[0 ]
258
258
url = f"https://console.cloud.google.com/vertex-ai/generative/language/locations/{ location } /tuning/tuningJob/{ job } ?project={ project } "
259
259
return url
260
+
261
+
262
+ def rebase_tuned_model (
263
+ tuned_model_ref : str ,
264
+ * ,
265
+ # TODO(b/372291558): Add support for overriding tuning job config
266
+ # tuning_job_config: Optional["TuningJob"] = None,
267
+ artifact_destination : Optional [str ] = None ,
268
+ deploy_to_same_endpoint : Optional [bool ] = False ,
269
+ ):
270
+ """Re-runs fine tuning on top of a new foundational model.
271
+
272
+ Takes a legacy Tuned GenAI model Reference and creates a TuningJob based
273
+ on a new model.
274
+
275
+ Args:
276
+ tuned_model_ref: Required. TunedModel reference to retrieve
277
+ the legacy model information.
278
+ tuning_job_config: The TuningJob to be updated. Users
279
+ can use this TuningJob field to overwrite tuning
280
+ configs.
281
+ artifact_destination: The Google Cloud Storage location to write the artifacts.
282
+ deploy_to_same_endpoint:
283
+ Optional. By default, bison to gemini
284
+ migration will always create new model/endpoint,
285
+ but for gemini-1.0 to gemini-1.5 migration, we
286
+ default deploy to the same endpoint. See details
287
+ in this Section.
288
+
289
+ Returns:
290
+ The new TuningJob.
291
+ """
292
+ parent = aiplatform_initializer .global_config .common_location_path (
293
+ project = aiplatform_initializer .global_config .project ,
294
+ location = aiplatform_initializer .global_config .location ,
295
+ )
296
+
297
+ if "/tuningJobs/" in tuned_model_ref :
298
+ gapic_tuned_model_ref = gca_types .TunedModelRef (
299
+ tuning_job = tuned_model_ref ,
300
+ )
301
+ elif "/pipelineJobs/" in tuned_model_ref :
302
+ gapic_tuned_model_ref = gca_types .TunedModelRef (
303
+ pipeline_job = tuned_model_ref ,
304
+ )
305
+ elif "/models/" in tuned_model_ref :
306
+ gapic_tuned_model_ref = gca_types .TunedModelRef (
307
+ tuned_model = tuned_model_ref ,
308
+ )
309
+ else :
310
+ raise ValueError (f"Unsupported tuned_model_ref: { tuned_model_ref } ." )
311
+
312
+ # gapic_tuning_job_config = tuning_job._gca_resource if tuning_job else None
313
+ gapic_tuning_job_config = None
314
+
315
+ gapic_artifact_destination = (
316
+ gca_types .GcsDestination (output_uri_prefix = artifact_destination )
317
+ if artifact_destination
318
+ else None
319
+ )
320
+
321
+ api_client : gen_ai_tuning_service_v1beta1 .GenAiTuningServiceClient = (
322
+ TuningJob ._instantiate_client (
323
+ location = aiplatform_initializer .global_config .location ,
324
+ credentials = aiplatform_initializer .global_config .credentials ,
325
+ )
326
+ )
327
+ rebase_operation = api_client .rebase_tuned_model (
328
+ gca_types .RebaseTunedModelRequest (
329
+ parent = parent ,
330
+ tuned_model_ref = gapic_tuned_model_ref ,
331
+ tuning_job = gapic_tuning_job_config ,
332
+ artifact_destination = gapic_artifact_destination ,
333
+ deploy_to_same_endpoint = deploy_to_same_endpoint ,
334
+ )
335
+ )
336
+ _LOGGER .log_create_with_lro (TuningJob , lro = rebase_operation )
337
+ gapic_rebase_tuning_job : gca_types .TuningJob = rebase_operation .result ()
338
+ rebase_tuning_job = TuningJob ._construct_sdk_resource_from_gapic (
339
+ gapic_resource = gapic_rebase_tuning_job ,
340
+ )
341
+ return rebase_tuning_job
0 commit comments