
Commit 768af67

jaycee-li authored and copybara-github committed
feat: GenAI - Allowed callable functions to return values directly in Automatic Function Calling
PiperOrigin-RevId: 640574734
1 parent 945b9e4 commit 768af67
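
For context, a minimal usage sketch of the behavior this commit enables: with Automatic Function Calling, a registered callable may now return a plain value (for example, a str) instead of a dict, and the SDK wraps it as {"result": <value>} before building the function response. The sketch mirrors the new unit test below; the project, location, model name, and prompt are placeholders, and running it requires Vertex AI credentials.

import vertexai
from vertexai.preview import generative_models

vertexai.init(project="my-project", location="us-central1")  # placeholder values

def get_current_weather(location: str) -> str:
    """Gets weather in the specified location."""
    # Returns a plain string, not a dict; the SDK now wraps it automatically.
    return "Super nice, but maybe a bit hot."

weather_tool = generative_models.Tool(
    function_declarations=[
        generative_models.FunctionDeclaration.from_func(get_current_weather)
    ],
)
model = generative_models.GenerativeModel("gemini-pro", tools=[weather_tool])
chat = model.start_chat(
    responder=generative_models.AutomaticFunctionCallingResponder(
        max_automatic_function_calls=5,
    )
)
response = chat.send_message("What is the weather like in Boston?")
print(response.text)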

2 files changed: +51 −1 lines changed

tests/unit/vertexai/test_generative_models.py

+46 −1
@@ -950,7 +950,7 @@ def test_generate_content_vertex_rag_retriever(self):
         attribute="generate_content",
         new=mock_generate_content,
     )
-    def test_chat_automatic_function_calling(self):
+    def test_chat_automatic_function_calling_with_function_returning_dict(self):
         generative_models = preview_generative_models
         get_current_weather_func = generative_models.FunctionDeclaration.from_func(
             get_current_weather
@@ -984,6 +984,51 @@ def test_chat_automatic_function_calling(self):
             chat2.send_message("What is the weather like in Boston?")
         assert err.match("Exceeded the maximum")

+    @mock.patch.object(
+        target=prediction_service.PredictionServiceClient,
+        attribute="generate_content",
+        new=mock_generate_content,
+    )
+    def test_chat_automatic_function_calling_with_function_returning_value(self):
+        # Define a new function that returns a value instead of a dict.
+        def get_current_weather(location: str):
+            """Gets weather in the specified location.
+
+            Args:
+                location: The location for which to get the weather.
+
+            Returns:
+                The weather information as a str.
+            """
+            if location == "Boston":
+                return "Super nice, but maybe a bit hot."
+            return "Unavailable"
+
+        generative_models = preview_generative_models
+        get_current_weather_func = generative_models.FunctionDeclaration.from_func(
+            get_current_weather
+        )
+        weather_tool = generative_models.Tool(
+            function_declarations=[get_current_weather_func],
+        )
+
+        model = generative_models.GenerativeModel(
+            "gemini-pro",
+            # Specifying the tools once to avoid specifying them in every request
+            tools=[weather_tool],
+        )
+        afc_responder = generative_models.AutomaticFunctionCallingResponder(
+            max_automatic_function_calls=5,
+        )
+        chat = model.start_chat(responder=afc_responder)
+
+        response1 = chat.send_message("What is the weather like in Boston?")
+        assert response1.text.startswith("The weather in Boston is")
+        assert "nice" in response1.text
+        assert len(chat.history) == 4
+        assert chat.history[-3].parts[0].function_call
+        assert chat.history[-2].parts[0].function_response
+

 EXPECTED_SCHEMA_FOR_GET_CURRENT_WEATHER = {
     "title": "get_current_weather",

vertexai/generative_models/_generative_models.py

+5
@@ -15,6 +15,7 @@
 """Classes for working with generative models."""
 # pylint: disable=bad-continuation, line-too-long, protected-access

+from collections.abc import Mapping
 import copy
 import io
 import json
@@ -2422,6 +2423,10 @@ def respond_to_model_response(
                 # due to: AttributeError: type object 'MapComposite' has no attribute 'to_dict'
                 function_args = type(function_call).to_dict(function_call)["args"]
                 function_call_result = callable_function._function(**function_args)
+                if not isinstance(function_call_result, Mapping):
+                    # If the function returns a single value, wrap it in the
+                    # format that Part.from_function_response can accept.
+                    function_call_result = {"result": function_call_result}
             except Exception as ex:
                 raise RuntimeError(
                     f"""Error raised when calling function "{function_call.name}" as requested by the model."""
