Skip to content

Commit d6490ff

Browse files
happy-qiaocopybara-github
authored andcommitted
feat: GenAI - Added function_calls shortcut property to Candidate class.
PiperOrigin-RevId: 613984044
1 parent a9010aa commit d6490ff

File tree

3 files changed

+26
-5
lines changed

3 files changed

+26
-5
lines changed

tests/system/vertexai/test_generative_models.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,12 @@ def test_generate_content_function_calling(self):
280280
== "get_current_weather"
281281
)
282282

283-
assert response.candidates[0].content.parts[0].function_call.args["location"]
283+
assert response.candidates[0].function_calls[0].args["location"]
284+
assert len(response.candidates[0].function_calls) == 1
285+
assert (
286+
response.candidates[0].function_calls[0]
287+
== response.candidates[0].content.parts[0].function_call
288+
)
284289

285290
# fake api_response data
286291
api_response = {
@@ -309,6 +314,7 @@ def test_generate_content_function_calling(self):
309314
tools=[weather_tool],
310315
)
311316
assert response
317+
assert len(response.candidates[0].function_calls) == 0
312318

313319
# Get the model summary response
314320
summary = response.candidates[0].content.parts[0].text

tests/unit/vertexai/test_generative_models.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
import vertexai
2424
from google.cloud.aiplatform import initializer
2525
from vertexai import generative_models
26-
from vertexai.preview import generative_models as preview_generative_models
26+
from vertexai.preview import (
27+
generative_models as preview_generative_models,
28+
)
2729
from vertexai.generative_models._generative_models import (
2830
prediction_service,
2931
gapic_prediction_service_types,
@@ -352,6 +354,10 @@ def test_chat_function_calling(self, generative_models: generative_models):
352354
response1.candidates[0].content.parts[0].function_call.name
353355
== "get_current_weather"
354356
)
357+
assert [
358+
function_call.name
359+
for function_call in response1.candidates[0].function_calls
360+
] == ["get_current_weather"]
355361
response2 = chat.send_message(
356362
generative_models.Part.from_function_response(
357363
name="get_current_weather",
@@ -361,6 +367,7 @@ def test_chat_function_calling(self, generative_models: generative_models):
361367
),
362368
)
363369
assert response2.text == "The weather in Boston is super nice!"
370+
assert len(response2.candidates[0].function_calls) == 0
364371

365372
@mock.patch.object(
366373
target=prediction_service.PredictionServiceClient,

vertexai/generative_models/_generative_models.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -1217,9 +1217,7 @@ def from_retrieval(
12171217
cls,
12181218
retrieval: "Retrieval",
12191219
):
1220-
raw_tool = gapic_tool_types.Tool(
1221-
retrieval=retrieval._raw_retrieval
1222-
)
1220+
raw_tool = gapic_tool_types.Tool(retrieval=retrieval._raw_retrieval)
12231221
return cls._from_gapic(raw_tool=raw_tool)
12241222

12251223
@classmethod
@@ -1460,6 +1458,16 @@ def citation_metadata(self) -> gapic_content_types.CitationMetadata:
14601458
def text(self) -> str:
14611459
return self.content.text
14621460

1461+
@property
1462+
def function_calls(self) -> Sequence[gapic_tool_types.FunctionCall]:
1463+
if not self.content or not self.content.parts:
1464+
return []
1465+
return [
1466+
part.function_call
1467+
for part in self.content.parts
1468+
if part and part.function_call
1469+
]
1470+
14631471

14641472
class Content:
14651473
r"""The multi-part content of a message.

0 commit comments

Comments
 (0)