|
19 | 19 | ThreadedGenerator,
|
20 | 20 | commit_conversation_trace,
|
21 | 21 | )
|
22 |
| -from khoj.utils.helpers import get_chat_usage_metrics, is_promptrace_enabled |
| 22 | +from khoj.utils.helpers import ( |
| 23 | + get_chat_usage_metrics, |
| 24 | + get_openai_client, |
| 25 | + is_promptrace_enabled, |
| 26 | +) |
23 | 27 |
|
24 | 28 | logger = logging.getLogger(__name__)
|
25 | 29 |
|
@@ -51,10 +55,7 @@ def completion_with_backoff(
|
51 | 55 | client_key = f"{openai_api_key}--{api_base_url}"
|
52 | 56 | client: openai.OpenAI | None = openai_clients.get(client_key)
|
53 | 57 | if not client:
|
54 |
| - client = openai.OpenAI( |
55 |
| - api_key=openai_api_key, |
56 |
| - base_url=api_base_url, |
57 |
| - ) |
| 58 | + client = get_openai_client(openai_api_key, api_base_url) |
58 | 59 | openai_clients[client_key] = client
|
59 | 60 |
|
60 | 61 | formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
@@ -158,14 +159,11 @@ def llm_thread(
|
158 | 159 | ):
|
159 | 160 | try:
|
160 | 161 | client_key = f"{openai_api_key}--{api_base_url}"
|
161 |
| - if client_key not in openai_clients: |
162 |
| - client = openai.OpenAI( |
163 |
| - api_key=openai_api_key, |
164 |
| - base_url=api_base_url, |
165 |
| - ) |
166 |
| - openai_clients[client_key] = client |
167 |
| - else: |
| 162 | + if client_key in openai_clients: |
168 | 163 | client = openai_clients[client_key]
|
| 164 | + else: |
| 165 | + client = get_openai_client(openai_api_key, api_base_url) |
| 166 | + openai_clients[client_key] = client |
169 | 167 |
|
170 | 168 | formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
171 | 169 |
|
|
0 commit comments