Skip to content

Commit e4b23a2

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: LLM - Support tuning for the code-bison model (preview)
PiperOrigin-RevId: 550751052
1 parent 75eb777 commit e4b23a2

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

vertexai/_model_garden/_model_garden_models.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
_SUPPORTED_PUBLISHERS = ["google"]
3232

3333
_SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP = {
34-
"text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0"
34+
"text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0",
35+
"code-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0",
3536
}
3637

3738
_SDK_PRIVATE_PREVIEW_LAUNCH_STAGE = frozenset(

vertexai/language_models/_language_models.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str:
5454
return tuning_model_id.replace(
5555
"text-bison-", "publishers/google/models/text-bison@"
5656
)
57-
raise ValueError(f"Unsupported tuning model ID {tuning_model_id}")
57+
if "/" not in tuning_model_id:
58+
return "publishers/google/models/" + tuning_model_id
59+
return tuning_model_id
5860

5961

6062
class _LanguageModel(_model_garden_models._ModelGardenModel):
@@ -1007,6 +1009,10 @@ def predict(
10071009
)
10081010

10091011

1012+
class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin):
1013+
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
1014+
1015+
10101016
###### Model tuning
10111017
# Currently, tuning can only work in this location
10121018
_TUNING_LOCATIONS = ("europe-west4", "us-central1")

vertexai/preview/language_models.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,21 @@
1616

1717
from vertexai.language_models._language_models import (
1818
_PreviewChatModel,
19+
_PreviewCodeGenerationModel,
1920
_PreviewTextEmbeddingModel,
2021
_PreviewTextGenerationModel,
2122
ChatMessage,
2223
ChatModel,
2324
ChatSession,
2425
CodeChatModel,
2526
CodeChatSession,
26-
CodeGenerationModel,
2727
InputOutputTextPair,
2828
TextEmbedding,
2929
TextGenerationResponse,
3030
)
3131

3232
ChatModel = _PreviewChatModel
33+
CodeGenerationModel = _PreviewCodeGenerationModel
3334
TextGenerationModel = _PreviewTextGenerationModel
3435
TextEmbeddingModel = _PreviewTextEmbeddingModel
3536

0 commit comments

Comments
 (0)