Skip to content

Commit 3cc6597

Browse files
authored
Support Azure OpenAI API endpoint (#1048)
OpenAI chat models deployed on Azure are (ironically) not OpenAI API compatible endpoints. This change enables using OpenAI chat models deployed on Azure with Khoj.
1 parent bac90ad commit 3cc6597

File tree

2 files changed

+28
-12
lines changed

2 files changed

+28
-12
lines changed

src/khoj/processor/conversation/openai/utils.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
ThreadedGenerator,
2020
commit_conversation_trace,
2121
)
22-
from khoj.utils.helpers import get_chat_usage_metrics, is_promptrace_enabled
22+
from khoj.utils.helpers import (
23+
get_chat_usage_metrics,
24+
get_openai_client,
25+
is_promptrace_enabled,
26+
)
2327

2428
logger = logging.getLogger(__name__)
2529

@@ -51,10 +55,7 @@ def completion_with_backoff(
5155
client_key = f"{openai_api_key}--{api_base_url}"
5256
client: openai.OpenAI | None = openai_clients.get(client_key)
5357
if not client:
54-
client = openai.OpenAI(
55-
api_key=openai_api_key,
56-
base_url=api_base_url,
57-
)
58+
client = get_openai_client(openai_api_key, api_base_url)
5859
openai_clients[client_key] = client
5960

6061
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
@@ -161,14 +162,11 @@ def llm_thread(
161162
):
162163
try:
163164
client_key = f"{openai_api_key}--{api_base_url}"
164-
if client_key not in openai_clients:
165-
client = openai.OpenAI(
166-
api_key=openai_api_key,
167-
base_url=api_base_url,
168-
)
169-
openai_clients[client_key] = client
170-
else:
165+
if client_key in openai_clients:
171166
client = openai_clients[client_key]
167+
else:
168+
client = get_openai_client(openai_api_key, api_base_url)
169+
openai_clients[client_key] = client
172170

173171
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
174172

src/khoj/utils/helpers.py

+18
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import TYPE_CHECKING, Any, Optional, Union
2323
from urllib.parse import urlparse
2424

25+
import openai
2526
import psutil
2627
import requests
2728
import torch
@@ -596,3 +597,20 @@ def get_chat_usage_metrics(
596597
"output_tokens": prev_usage["output_tokens"] + output_tokens,
597598
"cost": cost or get_cost_of_chat_message(model_name, input_tokens, output_tokens, prev_cost=prev_usage["cost"]),
598599
}
600+
601+
602+
def get_openai_client(api_key: str, api_base_url: str) -> Union[openai.OpenAI, openai.AzureOpenAI]:
603+
"""Get OpenAI or AzureOpenAI client based on the API Base URL"""
604+
parsed_url = urlparse(api_base_url)
605+
if parsed_url.hostname and parsed_url.hostname.endswith(".openai.azure.com"):
606+
client = openai.AzureOpenAI(
607+
api_key=api_key,
608+
azure_endpoint=api_base_url,
609+
api_version="2024-10-21",
610+
)
611+
else:
612+
client = openai.OpenAI(
613+
api_key=api_key,
614+
base_url=api_base_url,
615+
)
616+
return client

0 commit comments

Comments
 (0)