Skip to content

Commit 96e7f7d

Browse files
sararob authored and copybara-github committed
feat: add preview count_tokens method to CodeGenerationModel
PiperOrigin-RevId: 575318395
1 parent 01989b1 commit 96e7f7d

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

tests/unit/aiplatform/test_language_models.py

+37
Original file line numberDiff line numberDiff line change
@@ -2771,6 +2771,43 @@ def test_code_generation_multiple_candidates(self):
27712771
response.candidates[0].text == _TEST_CODE_GENERATION_PREDICTION["content"]
27722772
)
27732773

2774+
def test_code_generation_preview_count_tokens(self):
    """Tests the count_tokens method in CodeGenerationModel."""
    aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

    # Stub out the publisher-model lookup so from_pretrained() resolves
    # the code-gecko model without any network call.
    get_model_mock = mock.patch.object(
        target=model_garden_service_client.ModelGardenServiceClient,
        attribute="get_publisher_model",
        return_value=gca_publisher_model.PublisherModel(
            _CODE_COMPLETION_BISON_PUBLISHER_MODEL_DICT
        ),
    )
    with get_model_mock:
        model = preview_language_models.CodeGenerationModel.from_pretrained(
            "code-gecko@001"
        )

    expected_tokens = _TEST_COUNT_TOKENS_RESPONSE["total_tokens"]
    expected_chars = _TEST_COUNT_TOKENS_RESPONSE["total_billable_characters"]
    fake_count_response = gca_prediction_service_v1beta1.CountTokensResponse(
        total_tokens=expected_tokens,
        total_billable_characters=expected_chars,
    )

    # Patch the v1beta1 prediction client so count_tokens() returns the
    # canned response instead of hitting the CountTokens API.
    with mock.patch.object(
        target=prediction_service_client_v1beta1.PredictionServiceClient,
        attribute="count_tokens",
        return_value=fake_count_response,
    ):
        response = model.count_tokens("def reverse_string(s):")

    assert response.total_tokens == expected_tokens
    assert response.total_billable_characters == expected_chars
27742811
def test_code_completion(self):
27752812
"""Tests code completion with the code generation model."""
27762813
aiplatform.init(

vertexai/language_models/_language_models.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -2648,7 +2648,47 @@ async def predict_streaming_async(
26482648
yield _parse_text_generation_model_response(prediction_obj)
26492649

26502650

2651-
class _PreviewCodeGenerationModel(CodeGenerationModel, _TunableModelMixin):
2651+
class _CountTokensCodeGenerationMixin(_LanguageModel):
    """Mixin for code generation models that support the CountTokens API"""

    def count_tokens(
        self,
        prefix: str,
        *,
        suffix: Optional[str] = None,
    ) -> CountTokensResponse:
        """Counts the tokens and billable characters for a given code generation prompt.

        Note: this does not make a prediction request to the model, it only
        counts the tokens in the request.

        Args:
            prefix (str): Code before the current point.
            suffix (str): Code after the current point.

        Returns:
            A `CountTokensResponse` object that contains the number of tokens
            in the text and the number of billable characters.
        """
        # Build the instance in the same shape as a code-generation predict
        # request: prefix is required, suffix may be None.
        instance = {"prefix": prefix, "suffix": suffix}

        # CountTokens is only exposed on the v1beta1 surface, so pin the
        # prediction client to that version before calling.
        v1beta1_client = self._endpoint._prediction_client.select_version(
            "v1beta1"
        )
        raw_response = v1beta1_client.count_tokens(
            endpoint=self._endpoint_name,
            instances=[instance],
        )

        return CountTokensResponse(
            total_tokens=raw_response.total_tokens,
            total_billable_characters=raw_response.total_billable_characters,
            _count_tokens_response=raw_response,
        )
2687+
2688+
2689+
class _PreviewCodeGenerationModel(
2690+
CodeGenerationModel, _TunableModelMixin, _CountTokensCodeGenerationMixin
2691+
):
26522692
__name__ = "CodeGenerationModel"
26532693
__module__ = "vertexai.preview.language_models"
26542694

0 commit comments

Comments
 (0)