Skip to content

Commit ec4ec8f

Browse files
yinghsienwucopybara-github
authored andcommitted
BREAKING_CHANGE: deprecate Vertex SDK data science package
PiperOrigin-RevId: 640614179
1 parent ebb8f62 commit ec4ec8f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+5
-19424
lines changed

tests/unit/aiplatform/test_language_models.py

-88
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@
6565
prediction_service as gca_prediction_service_v1beta1,
6666
)
6767

68-
import vertexai
6968
from vertexai.preview import (
7069
language_models as preview_language_models,
7170
)
@@ -4736,93 +4735,6 @@ def test_batch_prediction_for_text_embedding(self):
47364735
model_parameters={},
47374736
)
47384737

4739-
def test_text_generation_top_level_from_pretrained_preview(self):
4740-
"""Tests the text generation model."""
4741-
aiplatform.init(
4742-
project=_TEST_PROJECT,
4743-
location=_TEST_LOCATION,
4744-
)
4745-
with mock.patch.object(
4746-
target=model_garden_service_client.ModelGardenServiceClient,
4747-
attribute="get_publisher_model",
4748-
return_value=gca_publisher_model.PublisherModel(
4749-
_TEXT_BISON_PUBLISHER_MODEL_DICT
4750-
),
4751-
) as mock_get_publisher_model:
4752-
model = vertexai.preview.from_pretrained(
4753-
foundation_model_name="text-bison@001"
4754-
)
4755-
4756-
assert isinstance(model, preview_language_models.TextGenerationModel)
4757-
4758-
mock_get_publisher_model.assert_called_with(
4759-
name="publishers/google/models/text-bison@001", retry=base._DEFAULT_RETRY
4760-
)
4761-
assert mock_get_publisher_model.call_count == 1
4762-
4763-
assert (
4764-
model._model_resource_name
4765-
== f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/text-bison@001"
4766-
)
4767-
4768-
# Test that methods on TextGenerationModel still work as expected
4769-
gca_predict_response = gca_prediction_service.PredictResponse()
4770-
gca_predict_response.predictions.append(_TEST_TEXT_GENERATION_PREDICTION)
4771-
4772-
with mock.patch.object(
4773-
target=prediction_service_client.PredictionServiceClient,
4774-
attribute="predict",
4775-
return_value=gca_predict_response,
4776-
):
4777-
response = model.predict(
4778-
"What is the best recipe for banana bread? Recipe:",
4779-
max_output_tokens=128,
4780-
temperature=0.0,
4781-
top_p=1.0,
4782-
top_k=5,
4783-
)
4784-
4785-
assert response.text == _TEST_TEXT_GENERATION_PREDICTION["content"]
4786-
assert (
4787-
response.raw_prediction_response.predictions[0]
4788-
== _TEST_TEXT_GENERATION_PREDICTION
4789-
)
4790-
assert (
4791-
response.safety_attributes["Violent"]
4792-
== _TEST_TEXT_GENERATION_PREDICTION["safetyAttributes"]["scores"][0]
4793-
)
4794-
4795-
def test_text_embedding_top_level_from_pretrained_preview(self):
4796-
"""Tests the text embedding model."""
4797-
aiplatform.init(
4798-
project=_TEST_PROJECT,
4799-
location=_TEST_LOCATION,
4800-
)
4801-
with mock.patch.object(
4802-
target=model_garden_service_client.ModelGardenServiceClient,
4803-
attribute="get_publisher_model",
4804-
return_value=gca_publisher_model.PublisherModel(
4805-
_TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
4806-
),
4807-
) as mock_get_publisher_model:
4808-
model = vertexai.preview.from_pretrained(
4809-
foundation_model_name="textembedding-gecko@001"
4810-
)
4811-
4812-
assert isinstance(model, preview_language_models.TextEmbeddingModel)
4813-
4814-
assert (
4815-
model._endpoint_name
4816-
== f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/textembedding-gecko@001"
4817-
)
4818-
4819-
mock_get_publisher_model.assert_called_with(
4820-
name="publishers/google/models/textembedding-gecko@001",
4821-
retry=base._DEFAULT_RETRY,
4822-
)
4823-
4824-
assert mock_get_publisher_model.call_count == 1
4825-
48264738

48274739
# TODO (b/285946649): add more test coverage before public preview release
48284740
@pytest.mark.usefixtures("google_auth_mock")

tests/unit/aiplatform/test_vision_models.py

+1-36
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from google.cloud.aiplatform.compat.types import (
4040
publisher_model as gca_publisher_model,
4141
)
42-
import vertexai
42+
4343
from vertexai import vision_models as ga_vision_models
4444
from vertexai.preview import (
4545
vision_models as preview_vision_models,
@@ -221,48 +221,13 @@ def _get_image_generation_model(
221221

222222
return model
223223

224-
def _get_preview_image_generation_model_top_level_from_pretrained(
225-
self,
226-
) -> preview_vision_models.ImageGenerationModel:
227-
"""Gets the image generation model from the top-level vertexai.preview.from_pretrained method."""
228-
aiplatform.init(
229-
project=_TEST_PROJECT,
230-
location=_TEST_LOCATION,
231-
)
232-
with mock.patch.object(
233-
target=model_garden_service_client.ModelGardenServiceClient,
234-
attribute="get_publisher_model",
235-
return_value=gca_publisher_model.PublisherModel(
236-
_IMAGE_GENERATION_PUBLISHER_MODEL_DICT
237-
),
238-
) as mock_get_publisher_model:
239-
model = vertexai.preview.from_pretrained(
240-
foundation_model_name="imagegeneration@002"
241-
)
242-
243-
mock_get_publisher_model.assert_called_with(
244-
name="publishers/google/models/imagegeneration@002",
245-
retry=base._DEFAULT_RETRY,
246-
)
247-
248-
assert mock_get_publisher_model.call_count == 1
249-
250-
return model
251-
252224
def test_from_pretrained(self):
253225
model = self._get_image_generation_model()
254226
assert (
255227
model._endpoint_name
256228
== f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/imagegeneration@002"
257229
)
258230

259-
def test_top_level_from_pretrained_preview(self):
260-
model = self._get_preview_image_generation_model_top_level_from_pretrained()
261-
assert (
262-
model._endpoint_name
263-
== f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/imagegeneration@002"
264-
)
265-
266231
def test_generate_images(self):
267232
"""Tests the image generation model."""
268233
model = self._get_image_generation_model()

0 commit comments

Comments
 (0)