|
65 | 65 | prediction_service as gca_prediction_service_v1beta1,
|
66 | 66 | )
|
67 | 67 |
|
68 |
| -import vertexai |
69 | 68 | from vertexai.preview import (
|
70 | 69 | language_models as preview_language_models,
|
71 | 70 | )
|
@@ -4736,93 +4735,6 @@ def test_batch_prediction_for_text_embedding(self):
|
4736 | 4735 | model_parameters={},
|
4737 | 4736 | )
|
4738 | 4737 |
|
4739 |
| - def test_text_generation_top_level_from_pretrained_preview(self): |
4740 |
| - """Tests the text generation model.""" |
4741 |
| - aiplatform.init( |
4742 |
| - project=_TEST_PROJECT, |
4743 |
| - location=_TEST_LOCATION, |
4744 |
| - ) |
4745 |
| - with mock.patch.object( |
4746 |
| - target=model_garden_service_client.ModelGardenServiceClient, |
4747 |
| - attribute="get_publisher_model", |
4748 |
| - return_value=gca_publisher_model.PublisherModel( |
4749 |
| - _TEXT_BISON_PUBLISHER_MODEL_DICT |
4750 |
| - ), |
4751 |
| - ) as mock_get_publisher_model: |
4752 |
| - model = vertexai.preview.from_pretrained( |
4753 |
| - foundation_model_name="text-bison@001" |
4754 |
| - ) |
4755 |
| - |
4756 |
| - assert isinstance(model, preview_language_models.TextGenerationModel) |
4757 |
| - |
4758 |
| - mock_get_publisher_model.assert_called_with( |
4759 |
| - name="publishers/google/models/text-bison@001", retry=base._DEFAULT_RETRY |
4760 |
| - ) |
4761 |
| - assert mock_get_publisher_model.call_count == 1 |
4762 |
| - |
4763 |
| - assert ( |
4764 |
| - model._model_resource_name |
4765 |
| - == f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/text-bison@001" |
4766 |
| - ) |
4767 |
| - |
4768 |
| - # Test that methods on TextGenerationModel still work as expected |
4769 |
| - gca_predict_response = gca_prediction_service.PredictResponse() |
4770 |
| - gca_predict_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION) |
4771 |
| - |
4772 |
| - with mock.patch.object( |
4773 |
| - target=prediction_service_client.PredictionServiceClient, |
4774 |
| - attribute="predict", |
4775 |
| - return_value=gca_predict_response, |
4776 |
| - ): |
4777 |
| - response = model.predict( |
4778 |
| - "What is the best recipe for banana bread? Recipe:", |
4779 |
| - max_output_tokens=128, |
4780 |
| - temperature=0.0, |
4781 |
| - top_p=1.0, |
4782 |
| - top_k=5, |
4783 |
| - ) |
4784 |
| - |
4785 |
| - assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"] |
4786 |
| - assert ( |
4787 |
| - response.raw_prediction_response.predictions[0] |
4788 |
| - == _TEST_TEXT_GENERATION_PREDICTION |
4789 |
| - ) |
4790 |
| - assert ( |
4791 |
| - response.safety_attributes["Violent"] |
4792 |
| - == _TEST_TEXT_GENERATION_PREDICTION["safetyAttributes"]["scores"][0] |
4793 |
| - ) |
4794 |
| - |
4795 |
| - def test_text_embedding_top_level_from_pretrained_preview(self): |
4796 |
| - """Tests the text embedding model.""" |
4797 |
| - aiplatform.init( |
4798 |
| - project=_TEST_PROJECT, |
4799 |
| - location=_TEST_LOCATION, |
4800 |
| - ) |
4801 |
| - with mock.patch.object( |
4802 |
| - target=model_garden_service_client.ModelGardenServiceClient, |
4803 |
| - attribute="get_publisher_model", |
4804 |
| - return_value=gca_publisher_model.PublisherModel( |
4805 |
| - _TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT |
4806 |
| - ), |
4807 |
| - ) as mock_get_publisher_model: |
4808 |
| - model = vertexai.preview.from_pretrained( |
4809 |
| - foundation_model_name="textembedding-gecko@001" |
4810 |
| - ) |
4811 |
| - |
4812 |
| - assert isinstance(model, preview_language_models.TextEmbeddingModel) |
4813 |
| - |
4814 |
| - assert ( |
4815 |
| - model._endpoint_name |
4816 |
| - == f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/textembedding-gecko@001" |
4817 |
| - ) |
4818 |
| - |
4819 |
| - mock_get_publisher_model.assert_called_with( |
4820 |
| - name="publishers/google/models/textembedding-gecko@001", |
4821 |
| - retry=base._DEFAULT_RETRY, |
4822 |
| - ) |
4823 |
| - |
4824 |
| - assert mock_get_publisher_model.call_count == 1 |
4825 |
| - |
4826 | 4738 |
|
4827 | 4739 | # TODO (b/285946649): add more test coverage before public preview release
|
4828 | 4740 | @pytest.mark.usefixtures("google_auth_mock")
|
|
0 commit comments