|
101 | 101 | "prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/text_generation_1.0.0.yaml",
|
102 | 102 | },
|
103 | 103 | }
|
104 |
| - |
# Mocked Model Garden metadata for the textembedding-gecko@003 publisher
# model; used to stub ModelGardenServiceClient.get_publisher_model so that
# TextEmbeddingModel.from_pretrained resolves without a network call.
_TEXT_GECKO_PUBLISHER_MODEL_DICT = {
    "name": "publishers/google/models/textembedding-gecko",
    "version_id": "003",
    "open_source_category": "PROPRIETARY",
    "launch_stage": gca_publisher_model.PublisherModel.LaunchStage.GA,
    "publisher_model_template": "projects/{user-project}/locations/{location}/publishers/google/models/textembedding-gecko@003",
    "predict_schemata": {
        "instance_schema_uri": "gs://google-cloud-aiplatform/schema/predict/instance/text_embedding_1.0.0.yaml",
        # Fixed bucket typo "aiplatfrom" -> "aiplatform" to match the
        # instance/prediction schema URIs above and below.
        "parameters_schema_uri": "gs://google-cloud-aiplatform/schema/predict/params/text_embedding_1.0.0.yaml",
        "prediction_schema_uri": "gs://google-cloud-aiplatform/schema/predict/prediction/text_embedding_1.0.0.yaml",
    },
}
105 | 116 | _CHAT_BISON_PUBLISHER_MODEL_DICT = {
|
106 | 117 | "name": "publishers/google/models/chat-bison",
|
107 | 118 | "version_id": "001",
|
@@ -528,6 +539,105 @@ def reverse_string_2(s):""",
|
528 | 539 | },
|
529 | 540 | )
|
530 | 541 |
|
# Minimal KFP pipeline spec mimicking the compiled v1.1.x text-embedding
# tuning pipeline (kfp-2.6.0 output). Only root.inputDefinitions matters to
# these tests; components/tasks are deliberately empty.
# NOTE(review): name misspells "EMBEDDING" as "EMBEDING" — kept as-is because
# other test helpers reference this exact identifier.
_EMBEDING_MODEL_TUNING_PIPELINE_SPEC = {
    "components": {},
    "deploymentSpec": {},
    "pipelineInfo": {
        "description": "Pipeline definition for v1.1.x embedding tuning pipelines.",
        "name": "tune-text-embedding-model",
    },
    "root": {
        "dag": {"tasks": {}},
        "inputDefinitions": {
            # Pipeline input parameters; tune_model is expected to populate
            # these via runtime_config.parameter_values.
            "parameters": {
                "accelerator_count": {
                    "defaultValue": 4,
                    "description": "how many accelerators to use when running the\ncontainer.",
                    "isOptional": True,
                    "parameterType": "NUMBER_INTEGER",
                },
                "accelerator_type": {
                    "defaultValue": "NVIDIA_TESLA_V100",
                    "description": "the accelerator type for running the trainer component.",
                    "isOptional": True,
                    "parameterType": "STRING",
                },
                "base_model_version_id": {
                    "defaultValue": "textembedding-gecko@001",
                    "description": "which base model to tune. This may be any stable\nnumbered version, for example `textembedding-gecko@001`.",
                    "isOptional": True,
                    "parameterType": "STRING",
                },
                "batch_size": {
                    "defaultValue": 128,
                    "description": "training batch size.",
                    "isOptional": True,
                    "parameterType": "NUMBER_INTEGER",
                },
                "corpus_path": {
                    "description": "the GCS path to the corpus data location.",
                    "parameterType": "STRING",
                },
                "iterations": {
                    "defaultValue": 1000,
                    "description": "the number of steps to perform fine-tuning.",
                    "isOptional": True,
                    "parameterType": "NUMBER_INTEGER",
                },
                "location": {
                    "defaultValue": "us-central1",
                    "description": "GCP region to run the pipeline.",
                    "isOptional": True,
                    "parameterType": "STRING",
                },
                "machine_type": {
                    "defaultValue": "n1-standard-16",
                    "description": "the type of the machine to run the trainer component. For\nmore details about this input config, see:\nhttps://cloud.google.com/vertex-ai/docs/training/configure-compute.",
                    "isOptional": True,
                    "parameterType": "STRING",
                },
                "model_display_name": {
                    "defaultValue": "tuned-text-embedding-model",
                    "description": "output model display name.",
                    "isOptional": True,
                    "parameterType": "STRING",
                },
                "project": {
                    "description": "user's project id.",
                    "parameterType": "STRING",
                },
                "queries_path": {
                    "description": "the GCS path to the queries location.",
                    "parameterType": "STRING",
                },
                "task_type": {
                    "defaultValue": "DEFAULT",
                    "description": "the task type expected to be used during inference. Valid\nvalues are `DEFAULT`, `RETRIEVAL_QUERY`, `RETRIEVAL_DOCUMENT`,\n`SEMANTIC_SIMILARITY`, `CLASSIFICATION`, and `CLUSTERING`.",
                    "isOptional": True,
                    "parameterType": "STRING",
                },
                "test_label_path": {
                    "defaultValue": "",
                    "description": "the GCS path to the test label data location.",
                    "isOptional": True,
                    "parameterType": "STRING",
                },
                "train_label_path": {
                    "description": "the GCS path to the train label data location.",
                    "parameterType": "STRING",
                },
                "validation_label_path": {
                    "defaultValue": "",
                    "description": "The GCS path to the validation label data location.",
                    "isOptional": True,
                    "parameterType": "STRING",
                },
            }
        },
    },
    "schemaVersion": "2.1.0",
    "sdkVersion": "kfp-2.6.0",
}
531 | 641 | _TEST_PIPELINE_SPEC = {
|
532 | 642 | "components": {},
|
533 | 643 | "pipelineInfo": {"name": "evaluation-llm-text-generation-pipeline"},
|
@@ -641,6 +751,9 @@ def reverse_string_2(s):""",
|
641 | 751 | }
|
642 | 752 |
|
643 | 753 |
|
# JSON-serialized pipeline specs, used to parametrize tests that accept the
# pipeline template as a JSON string rather than a Python dict.
_EMBEDING_MODEL_TUNING_PIPELINE_SPEC_JSON = json.dumps(
    _EMBEDING_MODEL_TUNING_PIPELINE_SPEC,
)
_TEST_PIPELINE_SPEC_JSON = json.dumps(
    _TEST_PIPELINE_SPEC,
)
|
@@ -1460,6 +1573,18 @@ def mock_request_urlopen(request: str) -> Tuple[str, mock.MagicMock]:
|
1460 | 1573 | yield request.param, mock_urlopen
|
1461 | 1574 |
|
1462 | 1575 |
|
@pytest.fixture
def mock_request_urlopen_gecko(request: str) -> Tuple[str, mock.MagicMock]:
    """Patches urllib's ``urlopen`` so that downloading the pipeline template
    yields the embedding-tuning spec as a JSON string.

    Used with indirect parametrization: ``request.param`` is the template URL
    and is yielded back to the test alongside the urlopen mock.
    """
    spec_json = json.dumps(_EMBEDING_MODEL_TUNING_PIPELINE_SPEC)
    with mock.patch.object(urllib_request, "urlopen") as mock_urlopen:
        # urlopen(...).read().decode() -> serialized pipeline spec.
        mock_urlopen.return_value.read.return_value.decode.return_value = (
            spec_json
        )
        yield request.param, mock_urlopen
1463 | 1588 | @pytest.fixture
|
1464 | 1589 | def mock_request_urlopen_rlhf(request: str) -> Tuple[str, mock.MagicMock]:
|
1465 | 1590 | data = _TEST_RLHF_PIPELINE_SPEC
|
@@ -1528,6 +1653,21 @@ def get_endpoint_mock():
|
1528 | 1653 | yield get_endpoint_mock
|
1529 | 1654 |
|
1530 | 1655 |
|
@pytest.fixture
def mock_get_tuned_embedding_model(get_endpoint_mock):
    """Patches ``_TunableEmbeddingModelMixin.get_tuned_model`` to return a
    mock tuned model wired to the test model resource and endpoint.

    Fix: the internal mock was misleadingly named ``mock_text_generation_model``
    although this fixture patches the *embedding* model mixin; renamed for
    clarity (local name only — fixture interface unchanged).
    """
    with mock.patch.object(
        _language_models._TunableEmbeddingModelMixin, "get_tuned_model"
    ) as mock_embedding_model:
        mock_embedding_model.return_value._model_id = (
            test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
        )
        mock_embedding_model.return_value._endpoint_name = (
            test_constants.EndpointConstants._TEST_ENDPOINT_NAME
        )
        mock_embedding_model.return_value._endpoint = get_endpoint_mock
        yield mock_embedding_model
| 1670 | + |
1531 | 1671 | @pytest.fixture
|
1532 | 1672 | def mock_get_tuned_model(get_endpoint_mock):
|
1533 | 1673 | with mock.patch.object(
|
@@ -2134,6 +2274,66 @@ def test_text_generation_response_repr(self):
|
2134 | 2274 | assert "blocked" in response_repr
|
2135 | 2275 | assert "Violent" in response_repr
|
2136 | 2276 |
|
    # job_spec is consumed by the mock_load_yaml_and_json fixture; the
    # embedding spec JSON stands in for the compiled pipeline template.
    @pytest.mark.parametrize(
        "job_spec",
        [_EMBEDING_MODEL_TUNING_PIPELINE_SPEC_JSON],
    )
    # Indirect parametrization: the URL becomes request.param inside the
    # mock_request_urlopen_gecko fixture (the template registry URL).
    @pytest.mark.parametrize(
        "mock_request_urlopen_gecko",
        ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
        indirect=True,
    )
    def test_tune_text_embedding_model(
        self,
        mock_pipeline_service_create,
        mock_pipeline_job_get,
        mock_pipeline_bucket_exists,
        job_spec,
        mock_load_yaml_and_json,
        mock_gcs_from_string,
        mock_gcs_upload,
        mock_request_urlopen_gecko,
        mock_get_tuned_embedding_model,
    ):
        """Tests tuning the text embedding model.

        Verifies that TextEmbeddingModel.tune_model submits a pipeline job
        whose runtime parameters reflect the requested train_steps and
        accelerator, and that the returned tuning job yields a tuned model
        wired to the expected endpoint (via mock_get_tuned_embedding_model).
        """
        aiplatform.init(
            project=_TEST_PROJECT,
            location=_TEST_LOCATION,
            encryption_spec_key_name=_TEST_ENCRYPTION_KEY_NAME,
        )
        # Stub Model Garden so from_pretrained resolves gecko@003 from the
        # local fixture dict instead of making a network call.
        with mock.patch.object(
            target=model_garden_service_client.ModelGardenServiceClient,
            attribute="get_publisher_model",
            return_value=gca_publisher_model.PublisherModel(
                _TEXT_GECKO_PUBLISHER_MODEL_DICT
            ),
        ):
            model = language_models.TextEmbeddingModel.from_pretrained(
                "textembedding-gecko@003"
            )
            tuning_job = model.tune_model(
                training_data="gs://bucket/training.tsv",
                corpus_data="gs://bucket/corpus.jsonl",
                queries_data="gs://bucket/queries.jsonl",
                test_data="gs://bucket/test.tsv",
                tuned_model_location="us-central1",
                train_steps=10,
                accelerator="NVIDIA_TESLA_A100",
            )
            # Inspect the PipelineJob handed to the (mocked) pipeline service:
            # train_steps maps to "iterations", accelerator to
            # "accelerator_type" in the pipeline's parameter_values.
            call_kwargs = mock_pipeline_service_create.call_args[1]
            pipeline_arguments = call_kwargs[
                "pipeline_job"
            ].runtime_config.parameter_values
            assert pipeline_arguments["iterations"] == 10
            assert pipeline_arguments["accelerator_type"] == "NVIDIA_TESLA_A100"

            # Testing the tuned model
            tuned_model = tuning_job.get_tuned_model()
            assert (
                tuned_model._endpoint_name
                == test_constants.EndpointConstants._TEST_ENDPOINT_NAME
            )
2137 | 2337 | @pytest.mark.parametrize(
|
2138 | 2338 | "job_spec",
|
2139 | 2339 | [_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_JOB],
|
|
0 commit comments