Skip to content

Commit 9a19545

Browse files
matthew29tang authored and copybara-github committed
feat: Switch Python generateContent to call Unary API endpoint
PiperOrigin-RevId: 604369375
1 parent 3f817f4 commit 9a19545

File tree

2 files changed

+58
-24
lines changed

2 files changed

+58
-24
lines changed

tests/unit/vertexai/test_generative_models.py

+56-6
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,56 @@ def mock_stream_generate_content(
163163
yield response
164164

165165

def mock_generate_content(
    self,
    request: gapic_prediction_service_types.GenerateContentRequest,
    *,
    model: Optional[str] = None,
    contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
) -> gapic_prediction_service_types.GenerateContentResponse:
    """Mocks the unary PredictionServiceClient.generate_content RPC.

    Inspects the incoming request and returns a canned single response:
      * If the request carries tools and no function_response part has been
        sent yet, respond with a function-call part (so the client is driven
        to execute the function and reply).
      * If a function_response is already present, respond with the
        after-function-call text part.
      * Otherwise, a multi-content request (continued chat) gets a follow-up
        text answer, and a fresh single-turn request gets the default text part.

    Args:
        request: The GenerateContentRequest to answer.
        model: Unused; accepted for signature compatibility with the real RPC.
        contents: Unused; accepted for signature compatibility with the real RPC.

    Returns:
        A single GenerateContentResponse (unary endpoint — not a stream).
    """
    is_continued_chat = len(request.contents) > 1
    has_tools = bool(request.tools)

    if has_tools:
        # "in" on a proto message part checks field presence (proto-plus),
        # so this detects whether the caller already sent a function result.
        has_function_response = any(
            "function_response" in content.parts[0] for content in request.contents
        )
        needs_function_call = not has_function_response
        if needs_function_call:
            response_part_struct = _RESPONSE_FUNCTION_CALL_PART_STRUCT
        else:
            response_part_struct = _RESPONSE_AFTER_FUNCTION_CALL_PART_STRUCT
    elif is_continued_chat:
        response_part_struct = {"text": "Other planets may have different sky color."}
    else:
        response_part_struct = _RESPONSE_TEXT_PART_STRUCT

    return gapic_prediction_service_types.GenerateContentResponse(
        candidates=[
            gapic_content_types.Candidate(
                index=0,
                content=gapic_content_types.Content(
                    # Model currently does not identify itself
                    # role="model",
                    parts=[
                        gapic_content_types.Part(response_part_struct),
                    ],
                ),
                finish_reason=gapic_content_types.Candidate.FinishReason.STOP,
                safety_ratings=[
                    gapic_content_types.SafetyRating(rating)
                    for rating in _RESPONSE_SAFETY_RATINGS_STRUCT
                ],
                citation_metadata=gapic_content_types.CitationMetadata(
                    citations=[
                        gapic_content_types.Citation(_RESPONSE_CITATION_STRUCT),
                    ]
                ),
            ),
        ],
    )
166216
@pytest.mark.usefixtures("google_auth_mock")
167217
class TestGenerativeModels:
168218
"""Unit tests for the generative models."""
@@ -178,8 +228,8 @@ def teardown_method(self):
178228

179229
@mock.patch.object(
180230
target=prediction_service.PredictionServiceClient,
181-
attribute="stream_generate_content",
182-
new=mock_stream_generate_content,
231+
attribute="generate_content",
232+
new=mock_generate_content,
183233
)
184234
def test_generate_content(self):
185235
model = generative_models.GenerativeModel("gemini-pro")
@@ -212,8 +262,8 @@ def test_generate_content_streaming(self):
212262

213263
@mock.patch.object(
214264
target=prediction_service.PredictionServiceClient,
215-
attribute="stream_generate_content",
216-
new=mock_stream_generate_content,
265+
attribute="generate_content",
266+
new=mock_generate_content,
217267
)
218268
def test_chat_send_message(self):
219269
model = generative_models.GenerativeModel("gemini-pro")
@@ -225,8 +275,8 @@ def test_chat_send_message(self):
225275

226276
@mock.patch.object(
227277
target=prediction_service.PredictionServiceClient,
228-
attribute="stream_generate_content",
229-
new=mock_stream_generate_content,
278+
attribute="generate_content",
279+
new=mock_generate_content,
230280
)
231281
def test_chat_function_calling(self):
232282
get_current_weather_func = generative_models.FunctionDeclaration(

vertexai/generative_models/_generative_models.py

+2-18
Original file line numberDiff line numberDiff line change
@@ -431,15 +431,7 @@ def _generate_content(
431431
safety_settings=safety_settings,
432432
tools=tools,
433433
)
434-
# generate_content is not available
435-
# gapic_response = self._prediction_client.generate_content(request=request)
436-
gapic_response = None
437-
stream = self._prediction_client.stream_generate_content(request=request)
438-
for gapic_chunk in stream:
439-
if gapic_response:
440-
_append_gapic_response(gapic_response, gapic_chunk)
441-
else:
442-
gapic_response = gapic_chunk
434+
gapic_response = self._prediction_client.generate_content(request=request)
443435
return self._parse_response(gapic_response)
444436

445437
async def _generate_content_async(
@@ -473,17 +465,9 @@ async def _generate_content_async(
473465
safety_settings=safety_settings,
474466
tools=tools,
475467
)
476-
# generate_content is not available
477-
# gapic_response = await self._prediction_async_client.generate_content(request=request)
478-
gapic_response = None
479-
stream = await self._prediction_async_client.stream_generate_content(
468+
gapic_response = await self._prediction_async_client.generate_content(
480469
request=request
481470
)
482-
async for gapic_chunk in stream:
483-
if gapic_response:
484-
_append_gapic_response(gapic_response, gapic_chunk)
485-
else:
486-
gapic_response = gapic_chunk
487471
return self._parse_response(gapic_response)
488472

489473
def _generate_content_streaming(

0 commit comments

Comments (0)