Commit f0086df

jaycee-li authored and copybara-github committed
fix: GenAI - Capture content blocked case when validating responses
PiperOrigin-RevId: 616988650
1 parent 2dc7f41 commit f0086df

File tree

2 files changed (+56 -7 lines)

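Context for the fix: as the second diff below shows, `_validate_response` previously indexed `response.candidates[0]` unconditionally, so a response with an empty candidates list (what the service returns when the prompt itself is blocked) surfaced as a bare IndexError instead of a descriptive ResponseValidationError. A minimal standalone sketch of that failure mode (FakeResponse is a hypothetical stand-in, not SDK code):

# Hypothetical stand-in for a GenerateContentResponse whose prompt was blocked.
class FakeResponse:
    candidates = []  # blocked prompts come back with no candidates

response = FakeResponse()

# Old behavior: unguarded indexing fails with an unhelpful IndexError.
try:
    candidate = response.candidates[0]
except IndexError as err:
    print(f"IndexError: {err}")

# New behavior: check the blocked case first, then index safely.
if not response.candidates:
    print("Response was blocked; raise a descriptive ResponseValidationError.")
else:
    candidate = response.candidates[0]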

tests/unit/vertexai/test_generative_models.py

+39
@@ -133,6 +133,20 @@ def mock_generate_content(
         )
         return response
 
+    should_block = (
+        last_message_part.text
+        and "Please block with block_reason=OTHER" in last_message_part.text
+    )
+    if should_block:
+        response = gapic_prediction_service_types.GenerateContentResponse(
+            candidates=[],
+            prompt_feedback=gapic_prediction_service_types.GenerateContentResponse.PromptFeedback(
+                block_reason=gapic_prediction_service_types.GenerateContentResponse.PromptFeedback.BlockedReason.OTHER,
+                block_reason_message="Blocked for testing",
+            ),
+        )
+        return response
+
     is_continued_chat = len(request.contents) > 1
     has_retrieval = any(
         tool.retrieval or tool.google_search_retrieval for tool in request.tools
@@ -349,6 +363,31 @@ def test_chat_send_message_response_validation_errors(
         # Checking that history did not get updated
         assert len(chat.history) == 2
 
+    @mock.patch.object(
+        target=prediction_service.PredictionServiceClient,
+        attribute="generate_content",
+        new=mock_generate_content,
+    )
+    @pytest.mark.parametrize(
+        "generative_models",
+        [generative_models, preview_generative_models],
+    )
+    def test_chat_send_message_response_blocked_errors(
+        self, generative_models: generative_models
+    ):
+        model = generative_models.GenerativeModel("gemini-pro")
+        chat = model.start_chat()
+        response1 = chat.send_message("Why is sky blue?")
+        assert response1.text
+        assert len(chat.history) == 2
+
+        with pytest.raises(generative_models.ResponseValidationError) as e:
+            chat.send_message("Please block with block_reason=OTHER.")
+
+        assert e.match("Blocked for testing")
+        # Checking that history did not get updated
+        assert len(chat.history) == 2
+
     @mock.patch.object(
         target=prediction_service.PredictionServiceClient,
         attribute="generate_content",
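The new test follows the standard `unittest.mock.patch.object` plus `pytest.raises` pattern: the patched `generate_content` watches for a trigger phrase in the prompt and produces the blocked response. A stripped-down, hedged sketch of the same pattern, using hypothetical `Client`, `fake_generate_content`, and `BlockedError` stand-ins rather than the real SDK classes:

from unittest import mock

import pytest


class BlockedError(Exception):
    """Hypothetical stand-in for ResponseValidationError."""


class Client:
    """Hypothetical stand-in for PredictionServiceClient."""

    def generate_content(self, prompt: str) -> str:
        return "ok"


def fake_generate_content(self, prompt: str) -> str:
    # Like mock_generate_content above: a trigger phrase flips the blocked path.
    if "Please block" in prompt:
        raise BlockedError("Blocked for testing")
    return "ok"


@mock.patch.object(Client, "generate_content", new=fake_generate_content)
def test_blocked_prompt_raises():
    client = Client()
    assert client.generate_content("Why is sky blue?") == "ok"
    with pytest.raises(BlockedError) as e:
        client.generate_content("Please block this prompt.")
    assert e.match("Blocked for testing")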

vertexai/generative_models/_generative_models.py

+17 -7

@@ -640,13 +640,23 @@ def _validate_response(
     request_contents: Optional[List["Content"]] = None,
     response_chunks: Optional[List["GenerationResponse"]] = None,
 ) -> None:
-    candidate = response.candidates[0]
-    if candidate.finish_reason not in _SUCCESSFUL_FINISH_REASONS:
-        message = (
-            "The model response did not completed successfully.\n"
-            f"Finish reason: {candidate.finish_reason}.\n"
-            f"Finish message: {candidate.finish_message}.\n"
-            f"Safety ratings: {candidate.safety_ratings}.\n"
+    message = ""
+    if not response.candidates:
+        message += (
+            f"The model response was blocked due to {response._raw_response.prompt_feedback.block_reason}.\n"
+            f"Block reason message: {response._raw_response.prompt_feedback.block_reason_message}.\n"
+        )
+    else:
+        candidate = response.candidates[0]
+        if candidate.finish_reason not in _SUCCESSFUL_FINISH_REASONS:
+            message = (
+                "The model response did not complete successfully.\n"
+                f"Finish reason: {candidate.finish_reason}.\n"
+                f"Finish message: {candidate.finish_message}.\n"
+                f"Safety ratings: {candidate.safety_ratings}.\n"
+            )
+    if message:
+        message += (
             "To protect the integrity of the chat session, the request and response were not added to chat history.\n"
             "To skip the response validation, specify `model.start_chat(response_validation=False)`.\n"
             "Note that letting blocked or otherwise incomplete responses into chat history might lead to future interactions being blocked by the service."
