
Commit 43c4b36

Tests added; removed Vision model support, as it has been deprecated.

1 parent: e257079
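For callers, the practical effect is that any config still pointing at `gemini-pro-vision` now fails fast in `create()` with a `ValueError`. A minimal migration sketch, assuming the usual AutoGen config-list shape with `"api_type": "google"` (an assumption; the config format is not part of this diff), with the replacement model taken from the new error message:

```python
# Hypothetical config-list entries illustrating the migration; the
# "api_type": "google" routing key is assumed, not shown in this diff.

# Before: create() now raises ValueError for this deprecated model name.
old_entry = {"model": "gemini-pro-vision", "api_key": "<GOOGLE_API_KEY>", "api_type": "google"}

# After: a current multimodal model, as the new error message suggests.
new_entry = {"model": "gemini-1.5-flash", "api_key": "<GOOGLE_API_KEY>", "api_type": "google"}
```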

File tree

2 files changed: +135 −249 lines


autogen/oai/gemini.py (+70 −100)
```diff
@@ -174,14 +174,20 @@ def get_usage(response) -> Dict:
         }

     def create(self, params: Dict) -> ChatCompletion:
+
         if self.use_vertexai:
             self._initialize_vertexai(**params)
         else:
             assert ("project_id" not in params) and (
                 "location" not in params
             ), "Google Cloud project and compute location cannot be set when using an API Key!"
         model_name = params.get("model", "gemini-pro")
-        if not model_name:
+
+        if model_name == "gemini-pro-vision":
+            raise ValueError(
+                "Gemini 1.0 Pro vision ('gemini-pro-vision') has been deprecated, please consider switching to a different model, for example 'gemini-1.5-flash'."
+            )
+        elif not model_name:
             raise ValueError(
                 "Please provide a model name for the Gemini Client. "
                 "You can configure it in the OAI Config List file. "
@@ -197,7 +203,7 @@ def create(self, params: Dict) -> ChatCompletion:
         if "tools" in params:
             tools = self._tools_to_gemini_tools(params["tools"])
         else:
-            tools = []
+            tools = None

         generation_config = {
             gemini_term: params[autogen_term]
@@ -224,117 +230,81 @@ def create(self, params: Dict) -> ChatCompletion:
         # Maps the function call ids to function names so we can inject it into FunctionResponse messages
         self.tool_call_function_map: Dict[str, str] = {}

-        if "vision" not in model_name:
-            # A. create and call the chat model.
-            gemini_messages = self._oai_messages_to_gemini_messages(messages)
-            if self.use_vertexai:
-                model = GenerativeModel(
-                    model_name,
-                    generation_config=generation_config,
-                    safety_settings=safety_settings,
-                    system_instruction=system_instruction,
-                    tools=tools,
-                )
-
-                chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
-            else:
-                model = genai.GenerativeModel(
-                    model_name,
-                    generation_config=generation_config,
-                    safety_settings=safety_settings,
-                    system_instruction=system_instruction,
-                    tools=tools,
-                )
-
-                genai.configure(api_key=self.api_key)
-                chat = model.start_chat(history=gemini_messages[:-1])
+        # A. create and call the chat model.
+        gemini_messages = self._oai_messages_to_gemini_messages(messages)
+        if self.use_vertexai:
+            model = GenerativeModel(
+                model_name,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                system_instruction=system_instruction,
+                tools=tools,
+            )

-            response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings)
+            chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
+        else:
+            model = genai.GenerativeModel(
+                model_name,
+                generation_config=generation_config,
+                safety_settings=safety_settings,
+                system_instruction=system_instruction,
+                tools=tools,
+            )

-            # Extract text and tools from response
-            ans = ""
-            random_id = random.randint(0, 10000)
-            prev_function_calls = []
-            for part in response.parts:
-
-                # Function calls
-                if fn_call := part.function_call:
-
-                    # If we have a repeated function call, ignore it
-                    if fn_call not in prev_function_calls:
-                        autogen_tool_calls.append(
-                            ChatCompletionMessageToolCall(
-                                id=random_id,
-                                function={
-                                    "name": fn_call.name,
-                                    "arguments": (
-                                        json.dumps({key: val for key, val in fn_call.args.items()})
-                                        if fn_call.args is not None
-                                        else ""
-                                    ),
-                                },
-                                type="function",
-                            )
+            genai.configure(api_key=self.api_key)
+            chat = model.start_chat(history=gemini_messages[:-1])
+
+        response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings)
+
+        # Extract text and tools from response
+        ans = ""
+        random_id = random.randint(0, 10000)
+        prev_function_calls = []
+        for part in response.parts:
+
+            # Function calls
+            if fn_call := part.function_call:
+
+                # If we have a repeated function call, ignore it
+                if fn_call not in prev_function_calls:
+                    autogen_tool_calls.append(
+                        ChatCompletionMessageToolCall(
+                            id=random_id,
+                            function={
+                                "name": fn_call.name,
+                                "arguments": (
+                                    json.dumps({key: val for key, val in fn_call.args.items()})
+                                    if fn_call.args is not None
+                                    else ""
+                                ),
+                            },
+                            type="function",
                         )
+                    )

-                        prev_function_calls.append(fn_call)
-                        random_id += 1
-
-                # Plain text content
-                elif text := part.text:
-                    ans += text
-
-            # If we have function calls, ignore the text
-            # as it can be Gemini guessing the function response
-            if len(autogen_tool_calls) != 0:
-                ans = ""
-
-            prompt_tokens = response.usage_metadata.prompt_token_count
-            completion_tokens = response.usage_metadata.candidates_token_count
+                    prev_function_calls.append(fn_call)
+                    random_id += 1

-        elif model_name == "gemini-pro-vision":
-            # B. handle the vision model
-            if self.use_vertexai:
-                model = GenerativeModel(
-                    model_name,
-                    generation_config=generation_config,
-                    safety_settings=safety_settings,
-                    system_instruction=system_instruction,
-                )
-            else:
-                model = genai.GenerativeModel(
-                    model_name,
-                    generation_config=generation_config,
-                    safety_settings=safety_settings,
-                    system_instruction=system_instruction,
-                )
-                genai.configure(api_key=self.api_key)
-            # Gemini's vision model does not support chat history yet
-            # chat = model.start_chat(history=gemini_messages[:-1])
-            # response = chat.send_message(gemini_messages[-1].parts)
-            user_message = self._oai_content_to_gemini_content(messages[-1]["content"])
-            if len(messages) > 2:
-                warnings.warn(
-                    "Warning: Gemini's vision model does not support chat history yet.",
-                    "We only use the last message as the prompt.",
-                    UserWarning,
-                )
+            # Plain text content
+            elif text := part.text:
+                ans += text

-            response = model.generate_content(user_message, stream=stream)
-            if self.use_vertexai:
-                ans: str = response.candidates[0].content.parts[0].text
-            else:
-                ans: str = response._result.candidates[0].content.parts[0].text
+        # If we have function calls, ignore the text
+        # as it can be Gemini guessing the function response
+        if len(autogen_tool_calls) != 0:
+            ans = ""
+        else:
+            autogen_tool_calls = None

-            prompt_tokens = response.usage_metadata.prompt_token_count
-            completion_tokens = response.usage_metadata.candidates_token_count
+        prompt_tokens = response.usage_metadata.prompt_token_count
+        completion_tokens = response.usage_metadata.candidates_token_count

         # 3. convert output
         message = ChatCompletionMessage(
             role="assistant", content=ans, function_call=None, tool_calls=autogen_tool_calls
         )
         choices = [
-            Choice(finish_reason="tool_calls" if len(autogen_tool_calls) > 0 else "stop", index=0, message=message)
+            Choice(finish_reason="tool_calls" if autogen_tool_calls is not None else "stop", index=0, message=message)
         ]

         response_oai = ChatCompletion(
```
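The commit message mentions added tests; they are in the second changed file, which is not shown in this view. A minimal sketch of a regression test for the new guard, assuming the client class in `autogen/oai/gemini.py` is named `GeminiClient` and accepts an `api_key` keyword (both assumptions here):

```python
# Sketch of a regression test for the deprecation guard; GeminiClient and
# its api_key keyword are assumptions, since the test file is not shown.
import pytest

from autogen.oai.gemini import GeminiClient


def test_create_rejects_deprecated_vision_model():
    client = GeminiClient(api_key="fake-key")
    # The guard fires before any request is built, so no network or mock is needed.
    with pytest.raises(ValueError, match="has been deprecated"):
        client.create({"model": "gemini-pro-vision", "messages": []})
```

Note the related sentinel change in the diff: `tools` and `autogen_tool_calls` now use `None` instead of empty containers, so `finish_reason` can key off `autogen_tool_calls is not None` rather than a length check.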
