Skip to content

Commit 3a97c52

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: LLM - Added tuning support for chat-bison models
PiperOrigin-RevId: 555782339
1 parent 06c9d18 commit 3a97c52

File tree

3 files changed

+48
-3
lines changed

3 files changed

+48
-3
lines changed

tests/unit/aiplatform/test_language_models.py

+46-2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
language_models as preview_language_models,
5858
)
5959
from vertexai import language_models
60+
from vertexai.language_models import _language_models
6061
from google.cloud.aiplatform_v1 import Execution as GapicExecution
6162
from google.cloud.aiplatform.compat.types import (
6263
encryption_spec as gca_encryption_spec,
@@ -471,7 +472,7 @@ def get_endpoint_mock():
471472
@pytest.fixture
472473
def mock_get_tuned_model(get_endpoint_mock):
473474
with mock.patch.object(
474-
preview_language_models.TextGenerationModel, "get_tuned_model"
475+
_language_models._TunableModelMixin, "get_tuned_model"
475476
) as mock_text_generation_model:
476477
mock_text_generation_model._model_id = (
477478
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
@@ -634,7 +635,7 @@ def test_text_generation_ga(self):
634635
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
635636
indirect=True,
636637
)
637-
def test_tune_model(
638+
def test_tune_text_generation_model(
638639
self,
639640
mock_pipeline_service_create,
640641
mock_pipeline_job_get,
@@ -680,6 +681,49 @@ def test_tune_model(
680681
== _TEST_ENCRYPTION_KEY_NAME
681682
)
682683

684+
@pytest.mark.parametrize(
685+
"job_spec",
686+
[_TEST_PIPELINE_SPEC_JSON],
687+
)
688+
@pytest.mark.parametrize(
689+
"mock_request_urlopen",
690+
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
691+
indirect=True,
692+
)
693+
def test_tune_chat_model(
694+
self,
695+
mock_pipeline_service_create,
696+
mock_pipeline_job_get,
697+
mock_pipeline_bucket_exists,
698+
job_spec,
699+
mock_load_yaml_and_json,
700+
mock_gcs_from_string,
701+
mock_gcs_upload,
702+
mock_request_urlopen,
703+
mock_get_tuned_model,
704+
):
705+
"""Tests tuning a chat model."""
706+
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
707+
with mock.patch.object(
708+
target=model_garden_service_client.ModelGardenServiceClient,
709+
attribute="get_publisher_model",
710+
return_value=gca_publisher_model.PublisherModel(
711+
_CHAT_BISON_PUBLISHER_MODEL_DICT
712+
),
713+
):
714+
model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")
715+
716+
model.tune_model(
717+
training_data=_TEST_TEXT_BISON_TRAINING_DF,
718+
tuning_job_location="europe-west4",
719+
tuned_model_location="us-central1",
720+
)
721+
call_kwargs = mock_pipeline_service_create.call_args[1]
722+
pipeline_arguments = call_kwargs[
723+
"pipeline_job"
724+
].runtime_config.parameter_values
725+
assert pipeline_arguments["large_model_reference"] == "chat-bison@001"
726+
683727
@pytest.mark.usefixtures(
684728
"get_model_with_tuned_version_label_mock",
685729
"get_endpoint_with_models_mock",

vertexai/_model_garden/_model_garden_models.py

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
_SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP = {
3434
"text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0",
3535
"code-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0",
36+
"chat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
3637
}
3738

3839
_SDK_PRIVATE_PREVIEW_LAUNCH_STAGE = frozenset(

vertexai/language_models/_language_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ class ChatModel(_ChatModelBase):
692692
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"
693693

694694

695-
class _PreviewChatModel(ChatModel):
695+
class _PreviewChatModel(ChatModel, _TunableModelMixin):
696696
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE
697697

698698

0 commit comments

Comments
 (0)