
Commit be4bcca

🔥 refactor: clean openai azure code

1 parent 6ceb818 · commit be4bcca
File tree: 14 files changed, +108 -259 lines

src/app/api/chat/[provider]/agentRuntime.test.ts (+12 -12)

@@ -18,6 +18,7 @@ import {
   LobeOpenAI,
   LobeOpenRouterAI,
   LobePerplexityAI,
+  LobeRuntimeAI,
   LobeTogetherAI,
   LobeZhipuAI,
   ModelProvider,
@@ -70,33 +71,32 @@ describe('AgentRuntime', () => {
       const jwtPayload: JWTPayload = {
         apiKey: 'user-azure-key',
         endpoint: 'user-azure-endpoint',
-        useAzure: true,
+        azureApiVersion: '2024-02-01',
       };
       const runtime = await AgentRuntime.initializeWithUserPayload(
-        ModelProvider.OpenAI,
+        ModelProvider.Azure,
         jwtPayload,
       );
 
       expect(runtime).toBeInstanceOf(AgentRuntime);
-      expect(runtime['_runtime']).toBeInstanceOf(LobeOpenAI);
+      expect(runtime['_runtime']).toBeInstanceOf(LobeAzureOpenAI);
       expect(runtime['_runtime'].baseURL).toBe('user-azure-endpoint');
     });
     it('should initialize with azureOpenAIParams correctly', async () => {
-      const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' };
-      const azureOpenAIParams = {
-        apiVersion: 'custom-version',
-        model: 'custom-model',
-        useAzure: true,
+      const jwtPayload: JWTPayload = {
+        apiKey: 'user-openai-key',
+        endpoint: 'user-endpoint',
+        azureApiVersion: 'custom-version',
       };
+
       const runtime = await AgentRuntime.initializeWithUserPayload(
-        ModelProvider.OpenAI,
+        ModelProvider.Azure,
         jwtPayload,
-        azureOpenAIParams,
       );
 
       expect(runtime).toBeInstanceOf(AgentRuntime);
-      const openAIRuntime = runtime['_runtime'] as LobeOpenAI;
-      expect(openAIRuntime).toBeInstanceOf(LobeOpenAI);
+      const openAIRuntime = runtime['_runtime'] as LobeRuntimeAI;
+      expect(openAIRuntime).toBeInstanceOf(LobeAzureOpenAI);
     });
 
     it('should initialize with AzureAI correctly', async () => {
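
For orientation, the call shape these updated tests exercise is sketched below: Azure is now selected explicitly via `ModelProvider.Azure`, and the API version travels inside the JWT payload, so the old third `azureOpenAIParams` argument disappears. The payload values are the fixtures from the hunk above; the import paths and the `LobeAzureOpenAI` export are assumptions based on the surrounding test file.

```ts
import { JWTPayload } from '@/const/auth';
// Assumed import paths for illustration only.
import { LobeAzureOpenAI, ModelProvider } from '@/libs/agent-runtime';
import { AgentRuntime } from './agentRuntime';

// Sketch: initializing the Azure runtime after this refactor.
const jwtPayload: JWTPayload = {
  apiKey: 'user-azure-key',
  azureApiVersion: '2024-02-01',
  endpoint: 'user-azure-endpoint',
};

const runtime = await AgentRuntime.initializeWithUserPayload(ModelProvider.Azure, jwtPayload);

// Per the updated assertions, the inner runtime is the dedicated Azure implementation.
console.log(runtime['_runtime'] instanceof LobeAzureOpenAI); // true
console.log(runtime['_runtime'].baseURL); // 'user-azure-endpoint'
```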

src/app/api/chat/[provider]/agentRuntime.ts (+6 -29)

@@ -30,12 +30,6 @@ import { TraceClient } from '@/libs/traces';
 
 import apiKeyManager from '../apiKeyManager';
 
-interface AzureOpenAIParams {
-  apiVersion?: string;
-  model: string;
-  useAzure?: boolean;
-}
-
 export interface AgentChatOptions {
   enableTrace?: boolean;
   provider: string;
@@ -112,18 +106,14 @@ class AgentRuntime {
     });
   }
 
-  static async initializeWithUserPayload(
-    provider: string,
-    payload: JWTPayload,
-    azureOpenAI?: AzureOpenAIParams,
-  ) {
+  static async initializeWithUserPayload(provider: string, payload: JWTPayload) {
     let runtimeModel: LobeRuntimeAI;
 
     switch (provider) {
       default:
       case 'oneapi':
       case ModelProvider.OpenAI: {
-        runtimeModel = this.initOpenAI(payload, azureOpenAI);
+        runtimeModel = this.initOpenAI(payload);
         break;
       }
 
@@ -196,27 +186,14 @@ class AgentRuntime {
     return new AgentRuntime(runtimeModel);
   }
 
-  private static initOpenAI(payload: JWTPayload, azureOpenAI?: AzureOpenAIParams) {
-    const { OPENAI_API_KEY, OPENAI_PROXY_URL, AZURE_API_VERSION, AZURE_API_KEY, USE_AZURE_OPENAI } =
-      getServerConfig();
+  private static initOpenAI(payload: JWTPayload) {
+    const { OPENAI_API_KEY, OPENAI_PROXY_URL } = getServerConfig();
     const openaiApiKey = payload?.apiKey || OPENAI_API_KEY;
     const baseURL = payload?.endpoint || OPENAI_PROXY_URL;
 
-    const azureApiKey = payload.apiKey || AZURE_API_KEY;
-    const useAzure = azureOpenAI?.useAzure || USE_AZURE_OPENAI;
-    const apiVersion = azureOpenAI?.apiVersion || AZURE_API_VERSION;
+    const apiKey = apiKeyManager.pick(openaiApiKey);
 
-    const apiKey = apiKeyManager.pick(useAzure ? azureApiKey : openaiApiKey);
-
-    return new LobeOpenAI({
-      apiKey,
-      azureOptions: {
-        apiVersion,
-        model: azureOpenAI?.model,
-      },
-      baseURL,
-      useAzure,
-    });
+    return new LobeOpenAI({ apiKey, baseURL });
   }
 
   private static initAzureOpenAI(payload: JWTPayload) {
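
The hunk ends at the signature of `initAzureOpenAI`, so its body is not part of this excerpt. A plausible shape for such a dedicated initializer is sketched below; `AZURE_API_KEY` and `AZURE_API_VERSION` appear in the code removed above, but `AZURE_ENDPOINT`, the import paths, and the `LobeAzureOpenAI` constructor arguments are assumptions rather than the commit's actual implementation.

```ts
// Hypothetical sketch only; not the body that ships in this commit.
import { getServerConfig } from '@/config/server'; // assumed path
import { JWTPayload } from '@/const/auth';
import { LobeAzureOpenAI, LobeRuntimeAI } from '@/libs/agent-runtime'; // assumed path

import apiKeyManager from '../apiKeyManager';

const initAzureOpenAI = (payload: JWTPayload): LobeRuntimeAI => {
  // AZURE_API_KEY / AZURE_API_VERSION are named in the removed code; AZURE_ENDPOINT is assumed.
  const { AZURE_API_KEY, AZURE_API_VERSION, AZURE_ENDPOINT } = getServerConfig();

  const apiKey = apiKeyManager.pick(payload?.apiKey || AZURE_API_KEY);
  const endpoint = payload?.endpoint || AZURE_ENDPOINT;
  const apiVersion = payload?.azureApiVersion || AZURE_API_VERSION;

  // Constructor shape (endpoint, apiKey, apiVersion) is an assumption.
  return new LobeAzureOpenAI(endpoint, apiKey, apiVersion);
};
```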

src/app/api/chat/[provider]/route.test.ts (+1 -6)

@@ -42,7 +42,6 @@ describe('POST handler', () => {
       accessCode: 'test-access-code',
       apiKey: 'test-api-key',
       azureApiVersion: 'v1',
-      useAzure: true,
     });
 
     const mockRuntime: LobeRuntimeAI = { baseURL: 'abc', chat: vi.fn() };
@@ -56,11 +55,7 @@ describe('POST handler', () => {
 
     // Verify that the mocked functions were called correctly
     expect(getJWTPayload).toHaveBeenCalledWith('Bearer some-valid-token');
-    expect(spy).toHaveBeenCalledWith('test-provider', expect.anything(), {
-      apiVersion: 'v1',
-      model: 'test-model',
-      useAzure: true,
-    });
+    expect(spy).toHaveBeenCalledWith('test-provider', expect.anything());
   });
 
   it('should return Unauthorized error when LOBE_CHAT_AUTH_HEADER is missing', async () => {

src/app/api/chat/[provider]/route.ts (+1 -6)

@@ -29,12 +29,7 @@ export const POST = async (req: Request, { params }: { params: { provider: string
   const jwtPayload = await getJWTPayload(authorization);
   checkAuthMethod(jwtPayload.accessCode, jwtPayload.apiKey, oauthAuthorized);
 
-  const body = await req.clone().json();
-  const agentRuntime = await AgentRuntime.initializeWithUserPayload(provider, jwtPayload, {
-    apiVersion: jwtPayload.azureApiVersion,
-    model: body.model,
-    useAzure: jwtPayload.useAzure,
-  });
+  const agentRuntime = await AgentRuntime.initializeWithUserPayload(provider, jwtPayload);
 
   // ============ 2. create chat completion ============ //
 

src/const/auth.ts (-1)

@@ -24,7 +24,6 @@ export interface JWTPayload {
   endpoint?: string;
 
   azureApiVersion?: string;
-  useAzure?: boolean;
 
   awsAccessKeyId?: string;
   awsRegion?: string;
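
Pieced together from the context lines in this hunk, the provider-credential section of `JWTPayload` now reads roughly as below. Only fields visible in this diff (or referenced elsewhere in the commit) are shown; the optionality of `accessCode` and `apiKey` is assumed from how the rest of the code reads them.

```ts
// Partial reconstruction of JWTPayload after this commit; unrelated fields omitted.
export interface JWTPayload {
  accessCode?: string; // read via jwtPayload.accessCode in route.ts above (optionality assumed)
  apiKey?: string; // read via payload?.apiKey in agentRuntime.ts above (optionality assumed)
  endpoint?: string;

  azureApiVersion?: string; // the useAzure flag is removed; the version field remains

  awsAccessKeyId?: string;
  awsRegion?: string;
  // ...remaining provider fields are untouched by this commit
}
```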

src/libs/agent-runtime/groq/index.test.ts (+10 -7)

@@ -75,13 +75,16 @@ describe('LobeGroqAI', () => {
       });
 
       // Assert
-      expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
-        max_tokens: 1024,
-        messages: [{ content: 'Hello', role: 'user' }],
-        model: 'mistralai/mistral-7b-instruct:free',
-        temperature: 0.7,
-        top_p: 1,
-      });
+      expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
+        {
+          max_tokens: 1024,
+          messages: [{ content: 'Hello', role: 'user' }],
+          model: 'mistralai/mistral-7b-instruct:free',
+          temperature: 0.7,
+          top_p: 1,
+        },
+        { headers: { Accept: '*/*' } },
+      );
       expect(result).toBeInstanceOf(Response);
     });
 
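
The new second argument in this assertion matches the OpenAI Node SDK's call signature, where per-request options are passed after the request body. A minimal, self-contained sketch of that two-argument shape follows; the API key and base URL are placeholders, and only the request body values are taken from the test fixture.

```ts
import OpenAI from 'openai';

// Placeholder credentials; only the request body mirrors the test fixture above.
const client = new OpenAI({ apiKey: 'placeholder-key', baseURL: 'https://example.com/v1' });

const completion = await client.chat.completions.create(
  {
    max_tokens: 1024,
    messages: [{ content: 'Hello', role: 'user' }],
    model: 'mistralai/mistral-7b-instruct:free',
    temperature: 0.7,
    top_p: 1,
  },
  // Per-request options (headers, timeout, signal, ...) go in the second parameter,
  // which is what the updated assertions now check for.
  { headers: { Accept: '*/*' } },
);

console.log(completion.choices[0]?.message.content);
```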

src/libs/agent-runtime/mistral/index.test.ts (+22 -16)

@@ -75,14 +75,17 @@ describe('LobeMistralAI', () => {
       });
 
       // Assert
-      expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
-        max_tokens: 1024,
-        messages: [{ content: 'Hello', role: 'user' }],
-        model: 'open-mistral-7b',
-        stream: true,
-        temperature: 0.7,
-        top_p: 1,
-      });
+      expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
+        {
+          max_tokens: 1024,
+          messages: [{ content: 'Hello', role: 'user' }],
+          model: 'open-mistral-7b',
+          stream: true,
+          temperature: 0.7,
+          top_p: 1,
+        },
+        { headers: { Accept: '*/*' } },
+      );
       expect(result).toBeInstanceOf(Response);
     });
 
@@ -105,14 +108,17 @@ describe('LobeMistralAI', () => {
       });
 
      // Assert
-      expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
-        max_tokens: 1024,
-        messages: [{ content: 'Hello', role: 'user' }],
-        model: 'open-mistral-7b',
-        stream: true,
-        temperature: 0.7,
-        top_p: 1,
-      });
+      expect(instance['client'].chat.completions.create).toHaveBeenCalledWith(
+        {
+          max_tokens: 1024,
+          messages: [{ content: 'Hello', role: 'user' }],
+          model: 'open-mistral-7b',
+          stream: true,
+          temperature: 0.7,
+          top_p: 1,
+        },
+        { headers: { Accept: '*/*' } },
+      );
       expect(result).toBeInstanceOf(Response);
     });
 

src/libs/agent-runtime/openai/index.test.ts (+2 -50)

@@ -3,7 +3,7 @@ import OpenAI from 'openai';
 import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
 
 // Import the module so we can spy on its functions
-import { ChatStreamCallbacks } from '@/libs/agent-runtime';
+import { ChatStreamCallbacks, LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';
 
 import * as debugStreamModule from '../utils/debugStream';
 import { LobeOpenAI } from './index';
@@ -12,7 +12,7 @@ import { LobeOpenAI } from './index';
 vi.spyOn(console, 'error').mockImplementation(() => {});
 
 describe('LobeOpenAI', () => {
-  let instance: LobeOpenAI;
+  let instance: LobeOpenAICompatibleRuntime;
 
   beforeEach(() => {
     instance = new LobeOpenAI({ apiKey: 'test' });
@@ -27,54 +27,6 @@ describe('LobeOpenAI', () => {
     vi.clearAllMocks();
   });
 
-  describe('init', () => {
-    it('should correctly initialize with Azure options', () => {
-      const baseURL = 'https://abc.com';
-      const modelName = 'abc';
-      const client = new LobeOpenAI({
-        apiKey: 'test',
-        useAzure: true,
-        baseURL,
-        azureOptions: {
-          apiVersion: '2023-08-01-preview',
-          model: 'abc',
-        },
-      });
-
-      expect(client.baseURL).toEqual(baseURL + '/openai/deployments/' + modelName);
-    });
-
-    describe('initWithAzureOpenAI', () => {
-      it('should correctly initialize with Azure options', () => {
-        const baseURL = 'https://abc.com';
-        const modelName = 'abc';
-        const client = LobeOpenAI.initWithAzureOpenAI({
-          apiKey: 'test',
-          useAzure: true,
-          baseURL,
-          azureOptions: {
-            apiVersion: '2023-08-01-preview',
-            model: 'abc',
-          },
-        });
-
-        expect(client.baseURL).toEqual(baseURL + '/openai/deployments/' + modelName);
-      });
-
-      it('should use default Azure options when not explicitly provided', () => {
-        const baseURL = 'https://abc.com';
-
-        const client = LobeOpenAI.initWithAzureOpenAI({
-          apiKey: 'test',
-          useAzure: true,
-          baseURL,
-        });
-
-        expect(client.baseURL).toEqual(baseURL + '/openai/deployments/');
-      });
-    });
-  });
-
   describe('chat', () => {
     it('should return a StreamingTextResponse on successful API call', async () => {
       // Arrange
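
One behaviour the deleted `init` / `initWithAzureOpenAI` tests pinned down is worth keeping in mind: in Azure mode the old `LobeOpenAI` rewrote its `baseURL` to the deployment path. The sketch below restates that URL rule exactly as the removed assertions describe it; the helper name is hypothetical, and after this commit the dedicated `LobeAzureOpenAI` runtime is responsible for this concern instead.

```ts
// Hypothetical helper restating the URL rule asserted by the deleted tests.
const buildAzureDeploymentURL = (baseURL: string, model?: string): string =>
  `${baseURL}/openai/deployments/${model ?? ''}`;

console.log(buildAzureDeploymentURL('https://abc.com', 'abc'));
// 'https://abc.com/openai/deployments/abc'  (the explicit-options case)
console.log(buildAzureDeploymentURL('https://abc.com'));
// 'https://abc.com/openai/deployments/'     (the "default Azure options" case)
```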
