|
19 | 19 | ThreadedGenerator,
|
20 | 20 | commit_conversation_trace,
|
21 | 21 | )
|
22 |
| -from khoj.utils.helpers import get_chat_usage_metrics, is_promptrace_enabled |
| 22 | +from khoj.utils.helpers import ( |
| 23 | + get_chat_usage_metrics, |
| 24 | + get_openai_client, |
| 25 | + is_promptrace_enabled, |
| 26 | +) |
23 | 27 |
|
24 | 28 | logger = logging.getLogger(__name__)
|
25 | 29 |
|
@@ -51,10 +55,7 @@ def completion_with_backoff(
|
51 | 55 | client_key = f"{openai_api_key}--{api_base_url}"
|
52 | 56 | client: openai.OpenAI | None = openai_clients.get(client_key)
|
53 | 57 | if not client:
|
54 |
| - client = openai.OpenAI( |
55 |
| - api_key=openai_api_key, |
56 |
| - base_url=api_base_url, |
57 |
| - ) |
| 58 | + client = get_openai_client(openai_api_key, api_base_url) |
58 | 59 | openai_clients[client_key] = client
|
59 | 60 |
|
60 | 61 | formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
@@ -161,14 +162,11 @@ def llm_thread(
|
161 | 162 | ):
|
162 | 163 | try:
|
163 | 164 | client_key = f"{openai_api_key}--{api_base_url}"
|
164 |
| - if client_key not in openai_clients: |
165 |
| - client = openai.OpenAI( |
166 |
| - api_key=openai_api_key, |
167 |
| - base_url=api_base_url, |
168 |
| - ) |
169 |
| - openai_clients[client_key] = client |
170 |
| - else: |
| 165 | + if client_key in openai_clients: |
171 | 166 | client = openai_clients[client_key]
|
| 167 | + else: |
| 168 | + client = get_openai_client(openai_api_key, api_base_url) |
| 169 | + openai_clients[client_key] = client |
172 | 170 |
|
173 | 171 | formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
174 | 172 |
|
|
0 commit comments