Skip to content

Commit 2cef97f

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: GenAI - Tuning - Added support for tuned model rebasing. Added rebase_tuned_model to vertexai.preview.tuning.sft.
PiperOrigin-RevId: 688795387
1 parent da76253 commit 2cef97f

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

vertexai/preview/tuning/sft.py

+4
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,12 @@
2020
train,
2121
SupervisedTuningJob,
2222
)
23+
from vertexai.tuning._tuning import (
24+
rebase_tuned_model,
25+
)
2326

2427
__all__ = [
28+
"rebase_tuned_model",
2529
"train",
2630
"SupervisedTuningJob",
2731
]

vertexai/tuning/_tuning.py

+82
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,85 @@ def _dashboard_url(self) -> str:
257257
job = list(fields.values())[0]
258258
url = f"https://console.cloud.google.com/vertex-ai/generative/language/locations/{location}/tuning/tuningJob/{job}?project={project}"
259259
return url
260+
261+
262+
def rebase_tuned_model(
263+
tuned_model_ref: str,
264+
*,
265+
# TODO(b/372291558): Add support for overriding tuning job config
266+
# tuning_job_config: Optional["TuningJob"] = None,
267+
artifact_destination: Optional[str] = None,
268+
deploy_to_same_endpoint: Optional[bool] = False,
269+
):
270+
"""Re-runs fine tuning on top of a new foundational model.
271+
272+
Takes a legacy Tuned GenAI model Reference and creates a TuningJob based
273+
on a new model.
274+
275+
Args:
276+
tuned_model_ref: Required. TunedModel reference to retrieve
277+
the legacy model information.
278+
tuning_job_config: The TuningJob to be updated. Users
279+
can use this TuningJob field to overwrite tuning
280+
configs.
281+
artifact_destination: The Google Cloud Storage location to write the artifacts.
282+
deploy_to_same_endpoint:
283+
Optional. By default, bison to gemini
284+
migration will always create new model/endpoint,
285+
but for gemini-1.0 to gemini-1.5 migration, we
286+
default deploy to the same endpoint. See details
287+
in this Section.
288+
289+
Returns:
290+
The new TuningJob.
291+
"""
292+
parent = aiplatform_initializer.global_config.common_location_path(
293+
project=aiplatform_initializer.global_config.project,
294+
location=aiplatform_initializer.global_config.location,
295+
)
296+
297+
if "/tuningJobs/" in tuned_model_ref:
298+
gapic_tuned_model_ref = gca_types.TunedModelRef(
299+
tuning_job=tuned_model_ref,
300+
)
301+
elif "/pipelineJobs/" in tuned_model_ref:
302+
gapic_tuned_model_ref = gca_types.TunedModelRef(
303+
pipeline_job=tuned_model_ref,
304+
)
305+
elif "/models/" in tuned_model_ref:
306+
gapic_tuned_model_ref = gca_types.TunedModelRef(
307+
tuned_model=tuned_model_ref,
308+
)
309+
else:
310+
raise ValueError(f"Unsupported tuned_model_ref: {tuned_model_ref}.")
311+
312+
# gapic_tuning_job_config = tuning_job._gca_resource if tuning_job else None
313+
gapic_tuning_job_config = None
314+
315+
gapic_artifact_destination = (
316+
gca_types.GcsDestination(output_uri_prefix=artifact_destination)
317+
if artifact_destination
318+
else None
319+
)
320+
321+
api_client: gen_ai_tuning_service_v1beta1.GenAiTuningServiceClient = (
322+
TuningJob._instantiate_client(
323+
location=aiplatform_initializer.global_config.location,
324+
credentials=aiplatform_initializer.global_config.credentials,
325+
)
326+
)
327+
rebase_operation = api_client.rebase_tuned_model(
328+
gca_types.RebaseTunedModelRequest(
329+
parent=parent,
330+
tuned_model_ref=gapic_tuned_model_ref,
331+
tuning_job=gapic_tuning_job_config,
332+
artifact_destination=gapic_artifact_destination,
333+
deploy_to_same_endpoint=deploy_to_same_endpoint,
334+
)
335+
)
336+
_LOGGER.log_create_with_lro(TuningJob, lro=rebase_operation)
337+
gapic_rebase_tuning_job: gca_types.TuningJob = rebase_operation.result()
338+
rebase_tuning_job = TuningJob._construct_sdk_resource_from_gapic(
339+
gapic_resource=gapic_rebase_tuning_job,
340+
)
341+
return rebase_tuning_job

0 commit comments

Comments
 (0)