@@ -163,6 +163,56 @@ def mock_stream_generate_content(
163
163
yield response
164
164
165
165
166
def mock_generate_content(
    self,
    request: gapic_prediction_service_types.GenerateContentRequest,
    *,
    model: Optional[str] = None,
    contents: Optional[MutableSequence[gapic_content_types.Content]] = None,
) -> gapic_prediction_service_types.GenerateContentResponse:
    """Mocks `PredictionServiceClient.generate_content` for unit tests.

    Picks a canned response part based on the request: when tools are
    attached, it either requests a function call or (if the conversation
    already contains a function_response part) answers with the
    post-function-call text; when the conversation has more than one
    content (a continued chat), it returns a follow-up answer; otherwise
    it returns the default text part.

    Args:
        request: The unary GenerateContentRequest to answer.
        model: Unused; accepted for signature compatibility with the
            real client method.
        contents: Unused; accepted for signature compatibility with the
            real client method.

    Returns:
        A single GenerateContentResponse — the unary API returns one
        response, not an iterable (the `Iterable[...]` annotation here
        was copied from the streaming mock and did not match the
        `return` statement).
    """
    is_continued_chat = len(request.contents) > 1
    has_tools = bool(request.tools)

    if has_tools:
        # A function_response part in the history means the fake model
        # already asked for (and received) a tool call; answer with text.
        has_function_response = any(
            "function_response" in content.parts[0] for content in request.contents
        )
        if has_function_response:
            response_part_struct = _RESPONSE_AFTER_FUNCTION_CALL_PART_STRUCT
        else:
            response_part_struct = _RESPONSE_FUNCTION_CALL_PART_STRUCT
    elif is_continued_chat:
        response_part_struct = {"text": "Other planets may have different sky color."}
    else:
        response_part_struct = _RESPONSE_TEXT_PART_STRUCT

    return gapic_prediction_service_types.GenerateContentResponse(
        candidates=[
            gapic_content_types.Candidate(
                index=0,
                content=gapic_content_types.Content(
                    # Model currently does not identify itself
                    # role="model",
                    parts=[
                        gapic_content_types.Part(response_part_struct),
                    ],
                ),
                finish_reason=gapic_content_types.Candidate.FinishReason.STOP,
                safety_ratings=[
                    gapic_content_types.SafetyRating(rating)
                    for rating in _RESPONSE_SAFETY_RATINGS_STRUCT
                ],
                citation_metadata=gapic_content_types.CitationMetadata(
                    citations=[
                        gapic_content_types.Citation(_RESPONSE_CITATION_STRUCT),
                    ]
                ),
            ),
        ],
    )
214
+
215
+
166
216
@pytest .mark .usefixtures ("google_auth_mock" )
167
217
class TestGenerativeModels :
168
218
"""Unit tests for the generative models."""
@@ -178,8 +228,8 @@ def teardown_method(self):
178
228
179
229
@mock .patch .object (
180
230
target = prediction_service .PredictionServiceClient ,
181
- attribute = "stream_generate_content " ,
182
- new = mock_stream_generate_content ,
231
+ attribute = "generate_content " ,
232
+ new = mock_generate_content ,
183
233
)
184
234
def test_generate_content (self ):
185
235
model = generative_models .GenerativeModel ("gemini-pro" )
@@ -212,8 +262,8 @@ def test_generate_content_streaming(self):
212
262
213
263
@mock .patch .object (
214
264
target = prediction_service .PredictionServiceClient ,
215
- attribute = "stream_generate_content " ,
216
- new = mock_stream_generate_content ,
265
+ attribute = "generate_content " ,
266
+ new = mock_generate_content ,
217
267
)
218
268
def test_chat_send_message (self ):
219
269
model = generative_models .GenerativeModel ("gemini-pro" )
@@ -225,8 +275,8 @@ def test_chat_send_message(self):
225
275
226
276
@mock .patch .object (
227
277
target = prediction_service .PredictionServiceClient ,
228
- attribute = "stream_generate_content " ,
229
- new = mock_stream_generate_content ,
278
+ attribute = "generate_content " ,
279
+ new = mock_generate_content ,
230
280
)
231
281
def test_chat_function_calling (self ):
232
282
get_current_weather_func = generative_models .FunctionDeclaration (
0 commit comments