|
57 | 57 | language_models as preview_language_models,
|
58 | 58 | )
|
59 | 59 | from vertexai import language_models
|
| 60 | +from vertexai.language_models import _language_models |
60 | 61 | from google.cloud.aiplatform_v1 import Execution as GapicExecution
|
61 | 62 | from google.cloud.aiplatform.compat.types import (
|
62 | 63 | encryption_spec as gca_encryption_spec,
|
@@ -471,7 +472,7 @@ def get_endpoint_mock():
|
471 | 472 | @pytest.fixture
|
472 | 473 | def mock_get_tuned_model(get_endpoint_mock):
|
473 | 474 | with mock.patch.object(
|
474 |
| - preview_language_models.TextGenerationModel, "get_tuned_model" |
| 475 | + _language_models._TunableModelMixin, "get_tuned_model" |
475 | 476 | ) as mock_text_generation_model:
|
476 | 477 | mock_text_generation_model._model_id = (
|
477 | 478 | test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
|
@@ -634,7 +635,7 @@ def test_text_generation_ga(self):
|
634 | 635 | ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
|
635 | 636 | indirect=True,
|
636 | 637 | )
|
637 |
| - def test_tune_model( |
| 638 | + def test_tune_text_generation_model( |
638 | 639 | self,
|
639 | 640 | mock_pipeline_service_create,
|
640 | 641 | mock_pipeline_job_get,
|
@@ -680,6 +681,49 @@ def test_tune_model(
|
680 | 681 | == _TEST_ENCRYPTION_KEY_NAME
|
681 | 682 | )
|
682 | 683 |
|
| 684 | + @pytest.mark.parametrize( |
| 685 | + "job_spec", |
| 686 | + [_TEST_PIPELINE_SPEC_JSON], |
| 687 | + ) |
| 688 | + @pytest.mark.parametrize( |
| 689 | + "mock_request_urlopen", |
| 690 | + ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"], |
| 691 | + indirect=True, |
| 692 | + ) |
| 693 | + def test_tune_chat_model( |
| 694 | + self, |
| 695 | + mock_pipeline_service_create, |
| 696 | + mock_pipeline_job_get, |
| 697 | + mock_pipeline_bucket_exists, |
| 698 | + job_spec, |
| 699 | + mock_load_yaml_and_json, |
| 700 | + mock_gcs_from_string, |
| 701 | + mock_gcs_upload, |
| 702 | + mock_request_urlopen, |
| 703 | + mock_get_tuned_model, |
| 704 | + ): |
| 705 | + """Tests tuning a chat model.""" |
| 706 | + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) |
| 707 | + with mock.patch.object( |
| 708 | + target=model_garden_service_client.ModelGardenServiceClient, |
| 709 | + attribute="get_publisher_model", |
| 710 | + return_value=gca_publisher_model.PublisherModel( |
| 711 | + _CHAT_BISON_PUBLISHER_MODEL_DICT |
| 712 | + ), |
| 713 | + ): |
| 714 | + model = preview_language_models.ChatModel.from_pretrained("chat-bison@001") |
| 715 | + |
| 716 | + model.tune_model( |
| 717 | + training_data=_TEST_TEXT_BISON_TRAINING_DF, |
| 718 | + tuning_job_location="europe-west4", |
| 719 | + tuned_model_location="us-central1", |
| 720 | + ) |
| 721 | + call_kwargs = mock_pipeline_service_create.call_args[1] |
| 722 | + pipeline_arguments = call_kwargs[ |
| 723 | + "pipeline_job" |
| 724 | + ].runtime_config.parameter_values |
| 725 | + assert pipeline_arguments["large_model_reference"] == "chat-bison@001" |
| 726 | + |
683 | 727 | @pytest.mark.usefixtures(
|
684 | 728 | "get_model_with_tuned_version_label_mock",
|
685 | 729 | "get_endpoint_with_models_mock",
|
|
0 commit comments