@@ -119,7 +119,7 @@ def mock_generate_content(
     *,
     model: Optional[str] = None,
     contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
-) -> Iterable[gapic_prediction_service_types.GenerateContentResponse]:
+) -> gapic_prediction_service_types.GenerateContentResponse:
     last_message_part = request.contents[-1].parts[0]
     should_fail = last_message_part.text and "Please fail" in last_message_part.text
     if should_fail:
@@ -203,8 +203,7 @@ def mock_generate_content(
             gapic_content_types.Candidate(
                 index=0,
                 content=gapic_content_types.Content(
-                    # Model currently does not identify itself
-                    # role="model",
+                    role="model",
                     parts=[
                         gapic_content_types.Part(response_part_struct),
                     ],
@@ -240,6 +239,13 @@ def mock_generate_content(
             ),
         ],
     )
+
+    if "Please block response with finish_reason=OTHER" in (
+        last_message_part.text or ""
+    ):
+        finish_reason = gapic_content_types.Candidate.FinishReason.OTHER
+        response.candidates[0].finish_reason = finish_reason
+
     return response


@@ -250,9 +256,32 @@ def mock_stream_generate_content(
     model: Optional[str] = None,
     contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
 ) -> Iterable[gapic_prediction_service_types.GenerateContentResponse]:
-    yield mock_generate_content(
+    response = mock_generate_content(
         self=self, request=request, model=model, contents=contents
     )
+    # When a streaming response gets blocked, the last chunk has no content.
+    # Create such a last chunk here.
+    blocked_chunk = None
+    candidate_0 = response.candidates[0] if response.candidates else None
+    if candidate_0 and candidate_0.finish_reason not in (
+        gapic_content_types.Candidate.FinishReason.STOP,
+        gapic_content_types.Candidate.FinishReason.MAX_TOKENS,
+    ):
+        blocked_chunk = gapic_prediction_service_types.GenerateContentResponse(
+            candidates=[
+                gapic_content_types.Candidate(
+                    index=0,
+                    finish_reason=candidate_0.finish_reason,
+                    finish_message=candidate_0.finish_message,
+                    safety_ratings=candidate_0.safety_ratings,
+                )
+            ]
+        )
+        candidate_0.finish_reason = None
+        candidate_0.finish_message = None
+    yield response
+    if blocked_chunk:
+        yield blocked_chunk


 def get_current_weather(location: str, unit: Optional[str] = "centigrade"):
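For readers of the mock above: a minimal standalone sketch of the chunk-splitting pattern it implements, using plain dataclasses instead of the real gapic types (all names here are illustrative only, not part of the diff). Content is streamed first; a blocked finish surfaces as a trailing, content-free chunk.

```python
from dataclasses import dataclass, field
from typing import Iterable, List, Optional


@dataclass
class Candidate:
    text: str = ""
    finish_reason: Optional[str] = None


@dataclass
class Chunk:
    candidates: List[Candidate] = field(default_factory=list)


def split_blocked_response(full: Candidate) -> Iterable[Chunk]:
    # Content goes out first, with the terminal finish metadata stripped...
    yield Chunk(candidates=[Candidate(text=full.text)])
    # ...and a non-STOP / non-MAX_TOKENS finish is delivered as a final
    # chunk that carries only the finish reason, mirroring the mock above.
    if full.finish_reason not in ("STOP", "MAX_TOKENS"):
        yield Chunk(candidates=[Candidate(finish_reason=full.finish_reason)])
```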
@@ -407,6 +436,25 @@ def test_chat_send_message(self, generative_models: generative_models):
         response2 = chat.send_message("Is sky blue on other planets?")
         assert response2.text

+    @mock.patch.object(
+        target=prediction_service.PredictionServiceClient,
+        attribute="stream_generate_content",
+        new=mock_stream_generate_content,
+    )
+    @pytest.mark.parametrize(
+        "generative_models",
+        [generative_models, preview_generative_models],
+    )
+    def test_chat_send_message_streaming(self, generative_models: generative_models):
+        model = generative_models.GenerativeModel("gemini-pro")
+        chat = model.start_chat()
+        stream1 = chat.send_message("Why is sky blue?", stream=True)
+        for chunk in stream1:
+            assert chunk.candidates
+        stream2 = chat.send_message("Is sky blue on other planets?", stream=True)
+        for chunk in stream2:
+            assert chunk.candidates
+
     @mock.patch.object(
         target=prediction_service.PredictionServiceClient,
         attribute="generate_content",
@@ -455,6 +503,27 @@ def test_chat_send_message_response_blocked_errors(
         # Checking that history did not get updated
         assert len(chat.history) == 2

+    @mock.patch.object(
+        target=prediction_service.PredictionServiceClient,
+        attribute="generate_content",
+        new=mock_generate_content,
+    )
+    @pytest.mark.parametrize(
+        "generative_models",
+        [generative_models, preview_generative_models],
+    )
+    def test_chat_send_message_response_candidate_blocked_error(
+        self, generative_models: generative_models
+    ):
+        model = generative_models.GenerativeModel("gemini-pro")
+        chat = model.start_chat()
+
+        with pytest.raises(generative_models.ResponseValidationError):
+            chat.send_message("Please block response with finish_reason=OTHER.")
+
+        # Checking that history did not get updated
+        assert not chat.history
+
     @mock.patch.object(
         target=prediction_service.PredictionServiceClient,
         attribute="generate_content",
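The new test exercises the blocked-candidate path end to end. As a rough sketch of how calling code might handle that error, assuming the same public surface the tests use (this snippet is not part of the diff):

```python
from vertexai.preview import generative_models

model = generative_models.GenerativeModel("gemini-pro")
chat = model.start_chat()

try:
    response = chat.send_message("Tell me about the weather.")
    print(response.text)
except generative_models.ResponseValidationError:
    # A blocked candidate (e.g. finish_reason=OTHER) raises instead of
    # returning text; per the test above, the failed turn is not added to
    # chat.history, so the conversation can continue with a rephrased prompt.
    print("Response was blocked; try rephrasing the prompt.")
```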