Skip to content

Commit af6e455

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: LLM - Added tuning support for codechat-bison models
PiperOrigin-RevId: 555829035
1 parent 3a97c52 commit af6e455

File tree

4 files changed

+54
-1
lines changed

4 files changed

+54
-1
lines changed

tests/unit/aiplatform/test_language_models.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,53 @@ def test_tune_chat_model(
724724
].runtime_config.parameter_values
725725
assert pipeline_arguments["large_model_reference"] == "chat-bison@001"
726726

727+
@pytest.mark.parametrize(
728+
"job_spec",
729+
[_TEST_PIPELINE_SPEC_JSON],
730+
)
731+
@pytest.mark.parametrize(
732+
"mock_request_urlopen",
733+
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
734+
indirect=True,
735+
)
736+
def test_tune_code_chat_model(
737+
self,
738+
mock_pipeline_service_create,
739+
mock_pipeline_job_get,
740+
mock_pipeline_bucket_exists,
741+
job_spec,
742+
mock_load_yaml_and_json,
743+
mock_gcs_from_string,
744+
mock_gcs_upload,
745+
mock_request_urlopen,
746+
mock_get_tuned_model,
747+
):
748+
"""Tests tuning a code chat model."""
749+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
750+
with mock.patch.object(
751+
target=model_garden_service_client.ModelGardenServiceClient,
752+
attribute="get_publisher_model",
753+
return_value=gca_publisher_model.PublisherModel(
754+
_CODECHAT_BISON_PUBLISHER_MODEL_DICT
755+
),
756+
):
757+
model = preview_language_models.CodeChatModel.from_pretrained(
758+
"codechat-bison@001"
759+
)
760+
761+
# The tune_model call needs to be inside the PublisherModel mock
762+
# since it gets a new PublisherModel when tuning completes.
763+
model.tune_model(
764+
training_data=_TEST_TEXT_BISON_TRAINING_DF,
765+
tuning_job_location="europe-west4",
766+
tuned_model_location="us-central1",
767+
)
768+
call_kwargs = mock_pipeline_service_create.call_args[1]
769+
pipeline_arguments = call_kwargs[
770+
"pipeline_job"
771+
].runtime_config.parameter_values
772+
assert pipeline_arguments["large_model_reference"] == "codechat-bison@001"
773+
727774
@pytest.mark.usefixtures(
728775
"get_model_with_tuned_version_label_mock",
729776
"get_endpoint_with_models_mock",

vertexai/_model_garden/_model_garden_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0",
3535
"code-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0",
3636
"chat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
37+
"codechat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
3738
}
3839

3940
_SDK_PRIVATE_PREVIEW_LAUNCH_STAGE = frozenset(

vertexai/language_models/_language_models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,10 @@ def start_chat(
739739
)
740740

741741

742+
class _PreviewCodeChatModel(CodeChatModel, _TunableModelMixin):
743+
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
744+
745+
742746
class _ChatSessionBase:
743747
"""_ChatSessionBase is a base class for all chat sessions."""
744748

vertexai/preview/language_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,21 @@
1616

1717
from vertexai.language_models._language_models import (
1818
_PreviewChatModel,
19+
_PreviewCodeChatModel,
1920
_PreviewCodeGenerationModel,
2021
_PreviewTextEmbeddingModel,
2122
_PreviewTextGenerationModel,
2223
ChatMessage,
2324
ChatModel,
2425
ChatSession,
25-
CodeChatModel,
2626
CodeChatSession,
2727
InputOutputTextPair,
2828
TextEmbedding,
2929
TextGenerationResponse,
3030
)
3131

3232
ChatModel = _PreviewChatModel
33+
CodeChatModel = _PreviewCodeChatModel
3334
CodeGenerationModel = _PreviewCodeGenerationModel
3435
TextGenerationModel = _PreviewTextGenerationModel
3536
TextEmbeddingModel = _PreviewTextEmbeddingModel

0 commit comments

Comments
 (0)