@@ -174,14 +174,20 @@ def get_usage(response) -> Dict:
        }

    def create(self, params: Dict) -> ChatCompletion:
+
        if self.use_vertexai:
            self._initialize_vertexai(**params)
        else:
            assert ("project_id" not in params) and (
                "location" not in params
            ), "Google Cloud project and compute location cannot be set when using an API Key!"
        model_name = params.get("model", "gemini-pro")
-        if not model_name:
+
+        if model_name == "gemini-pro-vision":
+            raise ValueError(
+                "Gemini 1.0 Pro vision ('gemini-pro-vision') has been deprecated, please consider switching to a different model, for example 'gemini-1.5-flash'."
+            )
+        elif not model_name:
            raise ValueError(
                "Please provide a model name for the Gemini Client. "
                "You can configure it in the OAI Config List file. "
@@ -197,7 +203,7 @@ def create(self, params: Dict) -> ChatCompletion:
        if "tools" in params:
            tools = self._tools_to_gemini_tools(params["tools"])
        else:
-            tools = []
+            tools = None

        generation_config = {
            gemini_term: params[autogen_term]
@@ -224,117 +230,81 @@ def create(self, params: Dict) -> ChatCompletion:
        # Maps the function call ids to function names so we can inject it into FunctionResponse messages
        self.tool_call_function_map: Dict[str, str] = {}

-        if "vision" not in model_name:
-            # A. create and call the chat model.
-            gemini_messages = self._oai_messages_to_gemini_messages(messages)
-            if self.use_vertexai:
-                model = GenerativeModel(
-                    model_name,
-                    generation_config=generation_config,
-                    safety_settings=safety_settings,
-                    system_instruction=system_instruction,
-                    tools=tools,
-                )
-
-                chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
-            else:
-                model = genai.GenerativeModel(
-                    model_name,
-                    generation_config=generation_config,
-                    safety_settings=safety_settings,
-                    system_instruction=system_instruction,
-                    tools=tools,
-                )
-
-                genai.configure(api_key=self.api_key)
-                chat = model.start_chat(history=gemini_messages[:-1])
+        # A. create and call the chat model.
+        gemini_messages = self._oai_messages_to_gemini_messages(messages)
+        if self.use_vertexai:
+            model = GenerativeModel(
+                model_name,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                system_instruction=system_instruction,
+                tools=tools,
+            )

-            response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings)
+            chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
+        else:
+            model = genai.GenerativeModel(
+                model_name,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                system_instruction=system_instruction,
+                tools=tools,
+            )

-            # Extract text and tools from response
-            ans = ""
-            random_id = random.randint(0, 10000)
-            prev_function_calls = []
-            for part in response.parts:
-
-                # Function calls
-                if fn_call := part.function_call:
-
-                    # If we have a repeated function call, ignore it
-                    if fn_call not in prev_function_calls:
-                        autogen_tool_calls.append(
-                            ChatCompletionMessageToolCall(
-                                id=random_id,
-                                function={
-                                    "name": fn_call.name,
-                                    "arguments": (
-                                        json.dumps({key: val for key, val in fn_call.args.items()})
-                                        if fn_call.args is not None
-                                        else ""
-                                    ),
-                                },
-                                type="function",
-                            )
+            genai.configure(api_key=self.api_key)
+            chat = model.start_chat(history=gemini_messages[:-1])
+
+        response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings)
+
+        # Extract text and tools from response
+        ans = ""
+        random_id = random.randint(0, 10000)
+        prev_function_calls = []
+        for part in response.parts:
+
+            # Function calls
+            if fn_call := part.function_call:
+
+                # If we have a repeated function call, ignore it
+                if fn_call not in prev_function_calls:
+                    autogen_tool_calls.append(
+                        ChatCompletionMessageToolCall(
+                            id=random_id,
+                            function={
+                                "name": fn_call.name,
+                                "arguments": (
+                                    json.dumps({key: val for key, val in fn_call.args.items()})
+                                    if fn_call.args is not None
+                                    else ""
+                                ),
+                            },
+                            type="function",
                         )
+                    )

-                        prev_function_calls.append(fn_call)
-                        random_id += 1
-
-                # Plain text content
-                elif text := part.text:
-                    ans += text
-
-            # If we have function calls, ignore the text
-            # as it can be Gemini guessing the function response
-            if len(autogen_tool_calls) != 0:
-                ans = ""
-
-            prompt_tokens = response.usage_metadata.prompt_token_count
-            completion_tokens = response.usage_metadata.candidates_token_count
+                    prev_function_calls.append(fn_call)
+                    random_id += 1

-        elif model_name == "gemini-pro-vision":
-            # B. handle the vision model
-            if self.use_vertexai:
-                model = GenerativeModel(
-                    model_name,
-                    generation_config=generation_config,
-                    safety_settings=safety_settings,
-                    system_instruction=system_instruction,
-                )
-            else:
-                model = genai.GenerativeModel(
-                    model_name,
-                    generation_config=generation_config,
-                    safety_settings=safety_settings,
-                    system_instruction=system_instruction,
-                )
-            genai.configure(api_key=self.api_key)
-            # Gemini's vision model does not support chat history yet
-            # chat = model.start_chat(history=gemini_messages[:-1])
-            # response = chat.send_message(gemini_messages[-1].parts)
-            user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
-            if len(messages) > 2:
-                warnings.warn(
-                    "Warning: Gemini's vision model does not support chat history yet.",
-                    "We only use the last message as the prompt.",
-                    UserWarning,
-                )
+            # Plain text content
+            elif text := part.text:
+                ans += text

-            response = model.generate_content(user_message, stream=stream)
-            if self.use_vertexai:
-                ans: str = response.candidates[0].content.parts[0].text
-            else:
-                ans: str = response._result.candidates[0].content.parts[0].text
+        # If we have function calls, ignore the text
+        # as it can be Gemini guessing the function response
+        if len(autogen_tool_calls) != 0:
+            ans = ""
+        else:
+            autogen_tool_calls = None

-            prompt_tokens = response.usage_metadata.prompt_token_count
-            completion_tokens = response.usage_metadata.candidates_token_count
+        prompt_tokens = response.usage_metadata.prompt_token_count
+        completion_tokens = response.usage_metadata.candidates_token_count

        # 3. convert output
        message = ChatCompletionMessage(
            role="assistant", content=ans, function_call=None, tool_calls=autogen_tool_calls
        )
        choices = [
-            Choice(finish_reason="tool_calls" if len(autogen_tool_calls) > 0 else "stop", index=0, message=message)
+            Choice(finish_reason="tool_calls" if autogen_tool_calls is not None else "stop", index=0, message=message)
        ]

        response_oai = ChatCompletion(
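
Below the diff, a minimal usage sketch (not part of this PR) of how the new deprecation guard surfaces to callers; the GeminiClient constructor call and message shape are assumptions for illustration, not taken from this change:

    # Hypothetical sketch: requesting the deprecated vision model now fails fast in create().
    client = GeminiClient(api_key="YOUR_KEY")  # assumed constructor arguments
    try:
        client.create({"model": "gemini-pro-vision", "messages": [{"role": "user", "content": "hi"}]})
    except ValueError as err:
        print(err)  # deprecation message suggests switching to e.g. 'gemini-1.5-flash'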