Skip to content

Commit efb8413

Browse files
vertex-sdk-bot and copybara-github
authored and committed
feat: Make count_tokens generally-available at TextEmbeddingModel.
PiperOrigin-RevId: 654133506
1 parent e5d087f commit efb8413

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

tests/unit/aiplatform/test_language_models.py

+42 -1
Original file line number | Diff line number | Diff line change
@@ -4526,7 +4526,48 @@ def test_text_embedding(self):
4526 4526
== expected_embedding["statistics"]["truncated"]
4527 4527
)
4528 4528

4529-
def test_text_embedding_preview_count_tokens(self):
4529+
def test_text_embedding_count_tokens_ga(self):
4530+
"""Tests the text embedding model."""
4531+
aiplatform.init(
4532+
project=_TEST_PROJECT,
4533+
location=_TEST_LOCATION,
4534+
)
4535+
with mock.patch.object(
4536+
target=model_garden_service_client.ModelGardenServiceClient,
4537+
attribute="get_publisher_model",
4538+
return_value=gca_publisher_model.PublisherModel(
4539+
_TEXT_EMBEDDING_GECKO_PUBLISHER_MODEL_DICT
4540+
),
4541+
):
4542+
model = language_models.TextEmbeddingModel.from_pretrained(
4543+
"textembedding-gecko@001"
4544+
)
4545+
4546+
gca_count_tokens_response = (
4547+
gca_prediction_service_v1beta1.CountTokensResponse(
4548+
total_tokens=_TEST_COUNT_TOKENS_RESPONSE["total_tokens"],
4549+
total_billable_characters=_TEST_COUNT_TOKENS_RESPONSE[
4550+
"total_billable_characters"
4551+
],
4552+
)
4553+
)
4554+
4555+
with mock.patch.object(
4556+
target=prediction_service_client_v1beta1.PredictionServiceClient,
4557+
attribute="count_tokens",
4558+
return_value=gca_count_tokens_response,
4559+
):
4560+
response = model.count_tokens(["What is life?"])
4561+
4562+
assert (
4563+
response.total_tokens == _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
4564+
)
4565+
assert (
4566+
response.total_billable_characters
4567+
== _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
4568+
)
4569+
4570+
def test_text_embedding_count_tokens_preview(self):
4530 4571
"""Tests the text embedding model."""
4531 4572
aiplatform.init(
4532 4573
project=_TEST_PROJECT,

vertexai/language_models/_language_models.py

+5 -1
Original file line number | Diff line number | Diff line change
@@ -2417,7 +2417,11 @@ class _TunableTextEmbeddingModelMixin(_PreviewTunableTextEmbeddingModelMixin):
2417 2417
pass
2418 2418

2419 2419

2420-
class TextEmbeddingModel(_TextEmbeddingModel, _TunableTextEmbeddingModelMixin):
2420+
class TextEmbeddingModel(
2421+
_TextEmbeddingModel,
2422+
_TunableTextEmbeddingModelMixin,
2423+
_CountTokensMixin,
2424+
):
2421 2425
__module__ = "vertexai.language_models"
2422 2426

2423 2427

0 commit comments

Comments
 (0)