
Commit fa35b91

Ark-kun authored and copybara-github committed
fix: GenAI - Workaround for streaming when content role is missing in service responses
PiperOrigin-RevId: 623296418
1 parent 40b728b commit fa35b91
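
For context, here is a minimal sketch of the failure mode this commit works around, assuming the proto-plus gapic content types shipped with google-cloud-aiplatform (the values are illustrative, not real service output): when a streamed response is blocked, the final chunk carries no content, so its role is the proto default empty string, and the pre-fix merge raised a role-mismatch ValueError.

# Minimal sketch of the pre-fix failure; assumes the gapic proto-plus types
# from google-cloud-aiplatform, with illustrative (not real) chunk values.
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types

base = gapic_content_types.Content(
    role="model", parts=[gapic_content_types.Part(text="Hi")]
)
# A blocked stream's final chunk has no content; a default-constructed
# Content has role == "" (the proto string default).
blocked = gapic_content_types.Content()

# Pre-fix check: "" != "model", so merging raised
#     ValueError: Content roles do not match: model !=
# The post-fix check below tolerates the empty incoming role:
if blocked.role and base.role != blocked.role:
    raise ValueError(f"Content roles do not match: {base.role} != {blocked.role}")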

File tree

2 files changed: +79 −6 lines changed


tests/unit/vertexai/test_generative_models.py

+73 −4
@@ -119,7 +119,7 @@ def mock_generate_content(
     *,
     model: Optional[str] = None,
     contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
-) -> Iterable[gapic_prediction_service_types.GenerateContentResponse]:
+) -> gapic_prediction_service_types.GenerateContentResponse:
     last_message_part = request.contents[-1].parts[0]
     should_fail = last_message_part.text and "Please fail" in last_message_part.text
     if should_fail:
@@ -203,8 +203,7 @@ def mock_generate_content(
             gapic_content_types.Candidate(
                 index=0,
                 content=gapic_content_types.Content(
-                    # Model currently does not identify itself
-                    # role="model",
+                    role="model",
                     parts=[
                         gapic_content_types.Part(response_part_struct),
                     ],
@@ -240,6 +239,13 @@ def mock_generate_content(
             ),
         ],
     )
+
+    if "Please block response with finish_reason=OTHER" in (
+        last_message_part.text or ""
+    ):
+        finish_reason = gapic_content_types.Candidate.FinishReason.OTHER
+        response.candidates[0].finish_reason = finish_reason
+
     return response


@@ -250,9 +256,32 @@ def mock_stream_generate_content(
     model: Optional[str] = None,
     contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
 ) -> Iterable[gapic_prediction_service_types.GenerateContentResponse]:
-    yield mock_generate_content(
+    response = mock_generate_content(
         self=self, request=request, model=model, contents=contents
     )
+    # When a streaming response gets blocked, the last chunk has no content.
+    # Create such a last chunk here.
+    blocked_chunk = None
+    candidate_0 = response.candidates[0] if response.candidates else None
+    if candidate_0 and candidate_0.finish_reason not in (
+        gapic_content_types.Candidate.FinishReason.STOP,
+        gapic_content_types.Candidate.FinishReason.MAX_TOKENS,
+    ):
+        blocked_chunk = gapic_prediction_service_types.GenerateContentResponse(
+            candidates=[
+                gapic_content_types.Candidate(
+                    index=0,
+                    finish_reason=candidate_0.finish_reason,
+                    finish_message=candidate_0.finish_message,
+                    safety_ratings=candidate_0.safety_ratings,
+                )
+            ]
+        )
+        candidate_0.finish_reason = None
+        candidate_0.finish_message = None
+    yield response
+    if blocked_chunk:
+        yield blocked_chunk


 def get_current_weather(location: str, unit: Optional[str] = "centigrade"):
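
To make the mock's contract concrete, here is a hedged illustration (the part text and finish reason are made up) of the two-chunk stream it yields for a blocked response: a content chunk with the finish metadata cleared, then a content-free final chunk carrying only the finish fields.

# Illustrative two-chunk stream for a blocked response, mirroring the mock
# above; assumes the same gapic proto-plus types used by these tests.
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
from google.cloud.aiplatform_v1beta1.types import (
    prediction_service as gapic_prediction_service_types,
)

chunk_1 = gapic_prediction_service_types.GenerateContentResponse(
    candidates=[
        gapic_content_types.Candidate(
            index=0,
            content=gapic_content_types.Content(
                role="model",
                parts=[gapic_content_types.Part(text="I cannot help with that.")],
            ),
        )
    ]
)
# The final chunk carries finish metadata but no content field at all.
chunk_2 = gapic_prediction_service_types.GenerateContentResponse(
    candidates=[
        gapic_content_types.Candidate(
            index=0,
            finish_reason=gapic_content_types.Candidate.FinishReason.OTHER,
        )
    ]
)
# proto-plus supports field-presence checks, which the fix below relies on:
assert "content" not in chunk_2.candidates[0]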
@@ -407,6 +436,25 @@ def test_chat_send_message(self, generative_models: generative_models):
         response2 = chat.send_message("Is sky blue on other planets?")
         assert response2.text

+    @mock.patch.object(
+        target=prediction_service.PredictionServiceClient,
+        attribute="stream_generate_content",
+        new=mock_stream_generate_content,
+    )
+    @pytest.mark.parametrize(
+        "generative_models",
+        [generative_models, preview_generative_models],
+    )
+    def test_chat_send_message_streaming(self, generative_models: generative_models):
+        model = generative_models.GenerativeModel("gemini-pro")
+        chat = model.start_chat()
+        stream1 = chat.send_message("Why is sky blue?", stream=True)
+        for chunk in stream1:
+            assert chunk.candidates
+        stream2 = chat.send_message("Is sky blue on other planets?", stream=True)
+        for chunk in stream2:
+            assert chunk.candidates
+
     @mock.patch.object(
         target=prediction_service.PredictionServiceClient,
         attribute="generate_content",
@@ -455,6 +503,27 @@ def test_chat_send_message_response_blocked_errors(
         # Checking that history did not get updated
         assert len(chat.history) == 2

+    @mock.patch.object(
+        target=prediction_service.PredictionServiceClient,
+        attribute="generate_content",
+        new=mock_generate_content,
+    )
+    @pytest.mark.parametrize(
+        "generative_models",
+        [generative_models, preview_generative_models],
+    )
+    def test_chat_send_message_response_candidate_blocked_error(
+        self, generative_models: generative_models
+    ):
+        model = generative_models.GenerativeModel("gemini-pro")
+        chat = model.start_chat()
+
+        with pytest.raises(generative_models.ResponseValidationError):
+            chat.send_message("Please block response with finish_reason=OTHER.")
+
+        # Checking that history did not get updated
+        assert not chat.history
+
     @mock.patch.object(
         target=prediction_service.PredictionServiceClient,
         attribute="generate_content",

vertexai/generative_models/_generative_models.py

+6 −2
@@ -2113,7 +2113,9 @@ def _append_gapic_candidate(
             f"Incorrect candidate indexes: {base_candidate.index} != {new_candidate.index}"
         )

-    _append_gapic_content(base_candidate.content, new_candidate.content)
+    # Only merge content if it exists.
+    if "content" in new_candidate:
+        _append_gapic_content(base_candidate.content, new_candidate.content)

     # For these attributes, the last value wins
     if new_candidate.finish_reason:
@@ -2130,7 +2132,9 @@ def _append_gapic_content(
     base_content: gapic_content_types.Content,
     new_content: gapic_content_types.Content,
 ):
-    if base_content.role != new_content.role:
+    # Handling an empty role is a workaround for the case when the service
+    # returns chunks with a missing role field (e.g. when the response is blocked).
+    if new_content.role and base_content.role != new_content.role:
         raise ValueError(
             f"Content roles do not match: {base_content.role} != {new_content.role}"
         )
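
Finally, a hedged end-to-end sketch of the streaming call path these changes unblock, using the public SDK surface the tests above exercise; the project and location are placeholders, and running it requires real credentials and a live endpoint.

# End-to-end streaming sketch; "my-project" / "us-central1" are placeholders.
import vertexai
from vertexai.generative_models import GenerativeModel

vertexai.init(project="my-project", location="us-central1")
model = GenerativeModel("gemini-pro")
chat = model.start_chat()

# Each streamed chunk is merged into the aggregated response and chat history;
# a blocked stream ends with a content-free chunk, which the merge logic in
# _append_gapic_candidate / _append_gapic_content now tolerates.
for chunk in chat.send_message("Why is sky blue?", stream=True):
    if chunk.candidates:
        print(chunk.candidates[0].finish_reason)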
