Skip to content

Commit 50fca69

Browse files
happy-qiao authored and copybara-github committed
feat: GenAI - Added system_instruction and tools support to GenerativeModel.count_tokens
PiperOrigin-RevId: 669000052
1 parent 20f2cad commit 50fca69

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

tests/system/vertexai/test_generative_models.py

+33
Original file line numberDiff line numberDiff line change
@@ -505,3 +505,36 @@ def test_compute_tokens_from_text(self, api_endpoint_env_name):
505505
assert token_info.role
506506
# Lightly validate that the tokens are not Base64 encoded
507507
assert b"=" not in token_info.tokens
508+
509+
def test_count_tokens_from_text(self):
510+
plain_model = generative_models.GenerativeModel(GEMINI_MODEL_NAME)
511+
model = generative_models.GenerativeModel(
512+
GEMINI_MODEL_NAME, system_instruction=["You are a chatbot."]
513+
)
514+
get_current_weather_func = generative_models.FunctionDeclaration.from_func(
515+
get_current_weather
516+
)
517+
weather_tool = generative_models.Tool(
518+
function_declarations=[get_current_weather_func],
519+
)
520+
content = ["Why is sky blue?", "Explain it like I'm 5."]
521+
522+
response_without_si = plain_model.count_tokens(content)
523+
response_with_si = model.count_tokens(content)
524+
response_with_si_and_tool = model.count_tokens(
525+
content,
526+
tools=[weather_tool],
527+
)
528+
529+
# system instruction + user prompt
530+
assert response_with_si.total_tokens > response_without_si.total_tokens
531+
assert (
532+
response_with_si.total_billable_characters
533+
> response_without_si.total_billable_characters
534+
)
535+
# system instruction + user prompt + tool
536+
assert response_with_si_and_tool.total_tokens > response_with_si.total_tokens
537+
assert (
538+
response_with_si_and_tool.total_billable_characters
539+
> response_with_si.total_billable_characters
540+
)

vertexai/generative_models/_generative_models.py

+21 -4
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,7 @@ async def async_generator():
824824
return async_generator()
825825

826826
def count_tokens(
827-
self, contents: ContentsType
827+
self, contents: ContentsType, *, tools: Optional[List["Tool"]] = None
828828
) -> gapic_prediction_service_types.CountTokensResponse:
829829
"""Counts tokens.
830830
@@ -836,22 +836,32 @@ def count_tokens(
836836
* str, Image, Part,
837837
* List[Union[str, Image, Part]],
838838
* List[Content]
839+
tools: A list of tools (functions) that the model can try calling.
839840
840841
Returns:
841842
A CountTokensResponse object that has the following attributes:
842843
total_tokens: The total number of tokens counted across all instances from the request.
843844
total_billable_characters: The total number of billable characters counted across all instances from the request.
844845
"""
846+
request = self._prepare_request(
847+
contents=contents,
848+
tools=tools,
849+
)
845850
return self._prediction_client.count_tokens(
846851
request=gapic_prediction_service_types.CountTokensRequest(
847852
endpoint=self._prediction_resource_name,
848853
model=self._prediction_resource_name,
849-
contents=self._prepare_request(contents=contents).contents,
854+
contents=request.contents,
855+
system_instruction=request.system_instruction,
856+
tools=request.tools,
850857
)
851858
)
852859

853860
async def count_tokens_async(
854-
self, contents: ContentsType
861+
self,
862+
contents: ContentsType,
863+
*,
864+
tools: Optional[List["Tool"]] = None,
855865
) -> gapic_prediction_service_types.CountTokensResponse:
856866
"""Counts tokens asynchronously.
857867
@@ -863,17 +873,24 @@ async def count_tokens_async(
863873
* str, Image, Part,
864874
* List[Union[str, Image, Part]],
865875
* List[Content]
876+
tools: A list of tools (functions) that the model can try calling.
866877
867878
Returns:
868879
An awaitable for a CountTokensResponse object that has the following attributes:
869880
total_tokens: The total number of tokens counted across all instances from the request.
870881
total_billable_characters: The total number of billable characters counted across all instances from the request.
871882
"""
883+
request = self._prepare_request(
884+
contents=contents,
885+
tools=tools,
886+
)
872887
return await self._prediction_async_client.count_tokens(
873888
request=gapic_prediction_service_types.CountTokensRequest(
874889
endpoint=self._prediction_resource_name,
875890
model=self._prediction_resource_name,
876-
contents=self._prepare_request(contents=contents).contents,
891+
contents=request.contents,
892+
system_instruction=request.system_instruction,
893+
tools=request.tools,
877894
)
878895
)
879896

0 commit comments

Comments
 (0)