Skip to content

Commit a78748e

Browse files
Ark-kuncopybara-github
authored andcommitted
feat: GenAI - Added to_dict() methods to response and content classes
Also fixed couple of existing methods that were broken. PiperOrigin-RevId: 607215896
1 parent 94f7cd9 commit a78748e

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

tests/unit/vertexai/test_generative_models.py

+31
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,37 @@ def test_chat_function_calling(self, generative_models: generative_models):
362362
)
363363
assert response2.text == "The weather in Boston is super nice!"
364364

365+
@mock.patch.object(
366+
target=prediction_service.PredictionServiceClient,
367+
attribute="generate_content",
368+
new=mock_generate_content,
369+
)
370+
@pytest.mark.parametrize(
371+
"generative_models",
372+
[generative_models, preview_generative_models],
373+
)
374+
def test_conversion_methods(self, generative_models: generative_models):
375+
"""Tests the .to_dict, .from_dict and __repr__ methods"""
376+
model = generative_models.GenerativeModel("gemini-pro")
377+
response = model.generate_content("Why is sky blue?")
378+
379+
response_new = generative_models.GenerationResponse.from_dict(
380+
response.to_dict()
381+
)
382+
assert repr(response_new) == repr(response)
383+
384+
for candidate in response.candidates:
385+
candidate_new = generative_models.Candidate.from_dict(candidate.to_dict())
386+
assert repr(candidate_new) == repr(candidate)
387+
388+
content = candidate.content
389+
content_new = generative_models.Content.from_dict(content.to_dict())
390+
assert repr(content_new) == repr(content)
391+
392+
for part in content.parts:
393+
part_new = generative_models.Part.from_dict(part.to_dict())
394+
assert repr(part_new) == repr(part)
395+
365396
@mock.patch.object(
366397
target=prediction_service.PredictionServiceClient,
367398
attribute="generate_content",

vertexai/generative_models/_generative_models.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -1137,8 +1137,11 @@ def from_dict(cls, generation_config_dict: Dict[str, Any]) -> "GenerationConfig"
11371137
)
11381138
return cls._from_gapic(raw_generation_config=raw_generation_config)
11391139

1140+
def to_dict(self) -> Dict[str, Any]:
1141+
return type(self._raw_generation_config).to_dict(self._raw_generation_config)
1142+
11401143
def __repr__(self):
1141-
return self._raw_tool.__repr__()
1144+
return self._raw_generation_config.__repr__()
11421145

11431146

11441147
class Tool:
@@ -1249,6 +1252,9 @@ def from_dict(cls, tool_dict: Dict[str, Any]) -> "Tool":
12491252
raw_tool = gapic_tool_types.Tool(tool_dict)
12501253
return cls._from_gapic(raw_tool=raw_tool)
12511254

1255+
def to_dict(self) -> Dict[str, Any]:
1256+
return type(self._raw_tool).to_dict(self._raw_tool)
1257+
12521258
def __repr__(self):
12531259
return self._raw_tool.__repr__()
12541260

@@ -1378,6 +1384,9 @@ def from_dict(cls, response_dict: Dict[str, Any]) -> "GenerationResponse":
13781384
)
13791385
return cls._from_gapic(raw_response=raw_response)
13801386

1387+
def to_dict(self) -> Dict[str, Any]:
1388+
return type(self._raw_response).to_dict(self._raw_response)
1389+
13811390
def __repr__(self):
13821391
return self._raw_response.__repr__()
13831392

@@ -1414,6 +1423,9 @@ def from_dict(cls, candidate_dict: Dict[str, Any]) -> "Candidate":
14141423
raw_candidate = gapic_content_types.Candidate(candidate_dict)
14151424
return cls._from_gapic(raw_candidate=raw_candidate)
14161425

1426+
def to_dict(self) -> Dict[str, Any]:
1427+
return type(self._raw_candidate).to_dict(self._raw_candidate)
1428+
14171429
def __repr__(self):
14181430
return self._raw_candidate.__repr__()
14191431

@@ -1480,6 +1492,9 @@ def from_dict(cls, content_dict: Dict[str, Any]) -> "Content":
14801492
raw_content = gapic_content_types.Content(content_dict)
14811493
return cls._from_gapic(raw_content=raw_content)
14821494

1495+
def to_dict(self) -> Dict[str, Any]:
1496+
return type(self._raw_content).to_dict(self._raw_content)
1497+
14831498
def __repr__(self):
14841499
return self._raw_content.__repr__()
14851500

@@ -1584,7 +1599,7 @@ def from_function_response(name: str, response: Dict[str, Any]) -> "Part":
15841599
)
15851600

15861601
def to_dict(self) -> Dict[str, Any]:
1587-
return self._raw_part.to_dict()
1602+
return type(self._raw_part).to_dict(self._raw_part)
15881603

15891604
@property
15901605
def text(self) -> str:

0 commit comments

Comments
 (0)