Skip to content

Commit 613ce69

Browse files
Ark-kuncopybara-github
authored andcommitted
fix: GenAI - Improved from_dict methods for content types (GenerationResponse, Candidate, Content, Part)
Workaround for issue in the proto-plus library: googleapis/proto-plus-python#424 Fixes #3194 PiperOrigin-RevId: 615334182
1 parent b30f5a6 commit 613ce69

File tree

2 files changed

+38
-9
lines changed

2 files changed

+38
-9
lines changed

tests/unit/vertexai/test_generative_models.py

+29-3
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,29 @@ def test_chat_function_calling(self, generative_models: generative_models):
379379
[generative_models, preview_generative_models],
380380
)
381381
def test_conversion_methods(self, generative_models: generative_models):
382-
"""Tests the .to_dict, .from_dict and __repr__ methods"""
383-
model = generative_models.GenerativeModel("gemini-pro")
384-
response = model.generate_content("Why is sky blue?")
382+
"""Tests the .to_dict, .from_dict and __repr__ methods."""
383+
# Testing on a full chat conversation which includes function calling
384+
get_current_weather_func = generative_models.FunctionDeclaration(
385+
name="get_current_weather",
386+
description="Get the current weather in a given location",
387+
parameters=_REQUEST_FUNCTION_PARAMETER_SCHEMA_STRUCT,
388+
)
389+
weather_tool = generative_models.Tool(
390+
function_declarations=[get_current_weather_func],
391+
)
392+
393+
model = generative_models.GenerativeModel("gemini-pro", tools=[weather_tool])
394+
chat = model.start_chat()
395+
response = chat.send_message("What is the weather like in Boston?")
396+
chat.send_message(
397+
generative_models.Part.from_function_response(
398+
name="get_current_weather",
399+
response={
400+
"location": "Boston",
401+
"weather": "super nice",
402+
},
403+
),
404+
)
385405

386406
response_new = generative_models.GenerationResponse.from_dict(
387407
response.to_dict()
@@ -400,6 +420,12 @@ def test_conversion_methods(self, generative_models: generative_models):
400420
part_new = generative_models.Part.from_dict(part.to_dict())
401421
assert repr(part_new) == repr(part)
402422

423+
# Checking the history which contains different Part types
424+
for content in chat.history:
425+
for part in content.parts:
426+
part_new = generative_models.Part.from_dict(part.to_dict())
427+
assert repr(part_new) == repr(part)
428+
403429
@mock.patch.object(
404430
target=prediction_service.PredictionServiceClient,
405431
attribute="generate_content",

vertexai/generative_models/_generative_models.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from vertexai.language_models import (
4646
_language_models as tunable_models,
4747
)
48+
from google.protobuf import json_format
4849
import warnings
4950

5051
try:
@@ -1377,9 +1378,8 @@ def _from_gapic(
13771378

13781379
@classmethod
13791380
def from_dict(cls, response_dict: Dict[str, Any]) -> "GenerationResponse":
1380-
raw_response = gapic_prediction_service_types.GenerateContentResponse(
1381-
response_dict
1382-
)
1381+
raw_response = gapic_prediction_service_types.GenerateContentResponse()
1382+
json_format.ParseDict(response_dict, raw_response._pb)
13831383
return cls._from_gapic(raw_response=raw_response)
13841384

13851385
def to_dict(self) -> Dict[str, Any]:
@@ -1418,7 +1418,8 @@ def _from_gapic(cls, raw_candidate: gapic_content_types.Candidate) -> "Candidate
14181418

14191419
@classmethod
14201420
def from_dict(cls, candidate_dict: Dict[str, Any]) -> "Candidate":
1421-
raw_candidate = gapic_content_types.Candidate(candidate_dict)
1421+
raw_candidate = gapic_content_types.Candidate()
1422+
json_format.ParseDict(candidate_dict, raw_candidate._pb)
14221423
return cls._from_gapic(raw_candidate=raw_candidate)
14231424

14241425
def to_dict(self) -> Dict[str, Any]:
@@ -1497,7 +1498,8 @@ def _from_gapic(cls, raw_content: gapic_content_types.Content) -> "Content":
14971498

14981499
@classmethod
14991500
def from_dict(cls, content_dict: Dict[str, Any]) -> "Content":
1500-
raw_content = gapic_content_types.Content(content_dict)
1501+
raw_content = gapic_content_types.Content()
1502+
json_format.ParseDict(content_dict, raw_content._pb)
15011503
return cls._from_gapic(raw_content=raw_content)
15021504

15031505
def to_dict(self) -> Dict[str, Any]:
@@ -1563,7 +1565,8 @@ def _from_gapic(cls, raw_part: gapic_content_types.Part) -> "Part":
15631565

15641566
@classmethod
15651567
def from_dict(cls, part_dict: Dict[str, Any]) -> "Part":
1566-
raw_part = gapic_content_types.Part(part_dict)
1568+
raw_part = gapic_content_types.Part()
1569+
json_format.ParseDict(part_dict, raw_part._pb)
15671570
return cls._from_gapic(raw_part=raw_part)
15681571

15691572
def __repr__(self):

0 commit comments

Comments
 (0)