Skip to content

Commit 89adf9d

Browse files
committed
♻️ refactor: refactor agent runtime with openai compatible factory
1 parent 78a1aac commit 89adf9d

File tree

14 files changed

+922
-500
lines changed

14 files changed

+922
-500
lines changed

src/libs/agent-runtime/BaseAI.ts

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
import { StreamingTextResponse } from 'ai';
2+
import OpenAI from 'openai';
23

34
import { ChatCompetitionOptions, ChatStreamPayload } from './types';
45

56
export interface LobeRuntimeAI {
67
baseURL?: string;
7-
88
chat(
99
payload: ChatStreamPayload,
1010
options?: ChatCompetitionOptions,
1111
): Promise<StreamingTextResponse>;
1212
}
13+
14+
export abstract class LobeOpenAICompatibleRuntime {
15+
abstract chat(
16+
payload: ChatStreamPayload,
17+
options?: ChatCompetitionOptions,
18+
): Promise<StreamingTextResponse>;
19+
20+
abstract client: OpenAI;
21+
22+
abstract baseURL: string;
23+
}
+347
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
// @vitest-environment node
2+
import OpenAI from 'openai';
3+
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
4+
5+
import { ChatStreamCallbacks, LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';
6+
7+
import * as debugStreamModule from '../utils/debugStream';
8+
import { LobeGroq } from './index';
9+
10+
const provider = 'groq';
11+
const defaultBaseURL = 'https://api.groq.com/openai/v1';
12+
const bizErrorType = 'GroqBizError';
13+
const invalidErrorType = 'InvalidGroqAPIKey';
14+
15+
// Mock the console.error to avoid polluting test output
16+
vi.spyOn(console, 'error').mockImplementation(() => {});
17+
18+
let instance: LobeOpenAICompatibleRuntime;
19+
20+
beforeEach(() => {
21+
instance = new LobeGroq({ apiKey: 'test' });
22+
23+
// 使用 vi.spyOn 来模拟 chat.completions.create 方法
24+
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
25+
new ReadableStream() as any,
26+
);
27+
});
28+
29+
afterEach(() => {
30+
vi.clearAllMocks();
31+
});
32+
33+
describe('LobeGroqAI', () => {
34+
describe('init', () => {
35+
it('should correctly initialize with an API key', async () => {
36+
const instance = new LobeGroq({ apiKey: 'test_api_key' });
37+
expect(instance).toBeInstanceOf(LobeGroq);
38+
expect(instance.baseURL).toEqual(defaultBaseURL);
39+
});
40+
});
41+
42+
describe('chat', () => {
43+
it('should return a StreamingTextResponse on successful API call', async () => {
44+
// Arrange
45+
const mockStream = new ReadableStream();
46+
const mockResponse = Promise.resolve(mockStream);
47+
48+
(instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);
49+
50+
// Act
51+
const result = await instance.chat({
52+
messages: [{ content: 'Hello', role: 'user' }],
53+
model: 'mistralai/mistral-7b-instruct:free',
54+
temperature: 0,
55+
});
56+
57+
// Assert
58+
expect(result).toBeInstanceOf(Response);
59+
});
60+
61+
it('should call OpenRouter API with corresponding options', async () => {
62+
// Arrange
63+
const mockStream = new ReadableStream();
64+
const mockResponse = Promise.resolve(mockStream);
65+
66+
(instance['client'].chat.completions.create as Mock).mockResolvedValue(mockResponse);
67+
68+
// Act
69+
const result = await instance.chat({
70+
max_tokens: 1024,
71+
messages: [{ content: 'Hello', role: 'user' }],
72+
model: 'mistralai/mistral-7b-instruct:free',
73+
temperature: 0.7,
74+
top_p: 1,
75+
});
76+
77+
// Assert
78+
expect(instance['client'].chat.completions.create).toHaveBeenCalledWith({
79+
max_tokens: 1024,
80+
messages: [{ content: 'Hello', role: 'user' }],
81+
model: 'mistralai/mistral-7b-instruct:free',
82+
temperature: 0.7,
83+
top_p: 1,
84+
});
85+
expect(result).toBeInstanceOf(Response);
86+
});
87+
88+
describe('Error', () => {
89+
it('should return OpenRouterBizError with an openai error response when OpenAI.APIError is thrown', async () => {
90+
// Arrange
91+
const apiError = new OpenAI.APIError(
92+
400,
93+
{
94+
status: 400,
95+
error: {
96+
message: 'Bad Request',
97+
},
98+
},
99+
'Error message',
100+
{},
101+
);
102+
103+
vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
104+
105+
// Act
106+
try {
107+
await instance.chat({
108+
messages: [{ content: 'Hello', role: 'user' }],
109+
model: 'mistralai/mistral-7b-instruct:free',
110+
temperature: 0,
111+
});
112+
} catch (e) {
113+
expect(e).toEqual({
114+
endpoint: defaultBaseURL,
115+
error: {
116+
error: { message: 'Bad Request' },
117+
status: 400,
118+
},
119+
errorType: bizErrorType,
120+
provider,
121+
});
122+
}
123+
});
124+
125+
it('should throw AgentRuntimeError with InvalidOpenRouterAPIKey if no apiKey is provided', async () => {
126+
try {
127+
new LobeGroq({});
128+
} catch (e) {
129+
expect(e).toEqual({ errorType: invalidErrorType });
130+
}
131+
});
132+
133+
it('should return OpenRouterBizError with the cause when OpenAI.APIError is thrown with cause', async () => {
134+
// Arrange
135+
const errorInfo = {
136+
stack: 'abc',
137+
cause: {
138+
message: 'api is undefined',
139+
},
140+
};
141+
const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});
142+
143+
vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
144+
145+
// Act
146+
try {
147+
await instance.chat({
148+
messages: [{ content: 'Hello', role: 'user' }],
149+
model: 'mistralai/mistral-7b-instruct:free',
150+
temperature: 0,
151+
});
152+
} catch (e) {
153+
expect(e).toEqual({
154+
endpoint: defaultBaseURL,
155+
error: {
156+
cause: { message: 'api is undefined' },
157+
stack: 'abc',
158+
},
159+
errorType: bizErrorType,
160+
provider,
161+
});
162+
}
163+
});
164+
165+
it('should return OpenRouterBizError with an cause response with desensitize Url', async () => {
166+
// Arrange
167+
const errorInfo = {
168+
stack: 'abc',
169+
cause: { message: 'api is undefined' },
170+
};
171+
const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});
172+
173+
instance = new LobeGroq({
174+
apiKey: 'test',
175+
176+
baseURL: 'https://api.abc.com/v1',
177+
});
178+
179+
vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);
180+
181+
// Act
182+
try {
183+
await instance.chat({
184+
messages: [{ content: 'Hello', role: 'user' }],
185+
model: 'mistralai/mistral-7b-instruct:free',
186+
temperature: 0,
187+
});
188+
} catch (e) {
189+
expect(e).toEqual({
190+
endpoint: 'https://api.***.com/v1',
191+
error: {
192+
cause: { message: 'api is undefined' },
193+
stack: 'abc',
194+
},
195+
errorType: bizErrorType,
196+
provider,
197+
});
198+
}
199+
});
200+
201+
it('should throw an InvalidOpenRouterAPIKey error type on 401 status code', async () => {
202+
// Mock the API call to simulate a 401 error
203+
const error = new Error('Unauthorized') as any;
204+
error.status = 401;
205+
vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error);
206+
207+
try {
208+
await instance.chat({
209+
messages: [{ content: 'Hello', role: 'user' }],
210+
model: 'mistralai/mistral-7b-instruct:free',
211+
temperature: 0,
212+
});
213+
} catch (e) {
214+
// Expect the chat method to throw an error with InvalidMoonshotAPIKey
215+
expect(e).toEqual({
216+
endpoint: defaultBaseURL,
217+
error: new Error('Unauthorized'),
218+
errorType: invalidErrorType,
219+
provider,
220+
});
221+
}
222+
});
223+
224+
it('should return AgentRuntimeError for non-OpenAI errors', async () => {
225+
// Arrange
226+
const genericError = new Error('Generic Error');
227+
228+
vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError);
229+
230+
// Act
231+
try {
232+
await instance.chat({
233+
messages: [{ content: 'Hello', role: 'user' }],
234+
model: 'mistralai/mistral-7b-instruct:free',
235+
temperature: 0,
236+
});
237+
} catch (e) {
238+
expect(e).toEqual({
239+
endpoint: defaultBaseURL,
240+
errorType: 'AgentRuntimeError',
241+
provider,
242+
error: {
243+
name: genericError.name,
244+
cause: genericError.cause,
245+
message: genericError.message,
246+
stack: genericError.stack,
247+
},
248+
});
249+
}
250+
});
251+
});
252+
253+
describe('LobeGroqAI chat with callback and headers', () => {
254+
it('should handle callback and headers correctly', async () => {
255+
// 模拟 chat.completions.create 方法返回一个可读流
256+
const mockCreateMethod = vi
257+
.spyOn(instance['client'].chat.completions, 'create')
258+
.mockResolvedValue(
259+
new ReadableStream({
260+
start(controller) {
261+
controller.enqueue({
262+
id: 'chatcmpl-8xDx5AETP8mESQN7UB30GxTN2H1SO',
263+
object: 'chat.completion.chunk',
264+
created: 1709125675,
265+
model: 'mistralai/mistral-7b-instruct:free',
266+
system_fingerprint: 'fp_86156a94a0',
267+
choices: [
268+
{ index: 0, delta: { content: 'hello' }, logprobs: null, finish_reason: null },
269+
],
270+
});
271+
controller.close();
272+
},
273+
}) as any,
274+
);
275+
276+
// 准备 callback 和 headers
277+
const mockCallback: ChatStreamCallbacks = {
278+
onStart: vi.fn(),
279+
onToken: vi.fn(),
280+
};
281+
const mockHeaders = { 'Custom-Header': 'TestValue' };
282+
283+
// 执行测试
284+
const result = await instance.chat(
285+
{
286+
messages: [{ content: 'Hello', role: 'user' }],
287+
model: 'mistralai/mistral-7b-instruct:free',
288+
temperature: 0,
289+
},
290+
{ callback: mockCallback, headers: mockHeaders },
291+
);
292+
293+
// 验证 callback 被调用
294+
await result.text(); // 确保流被消费
295+
expect(mockCallback.onStart).toHaveBeenCalled();
296+
expect(mockCallback.onToken).toHaveBeenCalledWith('hello');
297+
298+
// 验证 headers 被正确传递
299+
expect(result.headers.get('Custom-Header')).toEqual('TestValue');
300+
301+
// 清理
302+
mockCreateMethod.mockRestore();
303+
});
304+
});
305+
306+
describe('DEBUG', () => {
307+
it('should call debugStream and return StreamingTextResponse when DEBUG_OPENROUTER_CHAT_COMPLETION is 1', async () => {
308+
// Arrange
309+
const mockProdStream = new ReadableStream() as any; // 模拟的 prod 流
310+
const mockDebugStream = new ReadableStream({
311+
start(controller) {
312+
controller.enqueue('Debug stream content');
313+
controller.close();
314+
},
315+
}) as any;
316+
mockDebugStream.toReadableStream = () => mockDebugStream; // 添加 toReadableStream 方法
317+
318+
// 模拟 chat.completions.create 返回值,包括模拟的 tee 方法
319+
(instance['client'].chat.completions.create as Mock).mockResolvedValue({
320+
tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }],
321+
});
322+
323+
// 保存原始环境变量值
324+
const originalDebugValue = process.env.DEBUG_GROQ_CHAT_COMPLETION;
325+
326+
// 模拟环境变量
327+
process.env.DEBUG_GROQ_CHAT_COMPLETION = '1';
328+
vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve());
329+
330+
// 执行测试
331+
// 运行你的测试函数,确保它会在条件满足时调用 debugStream
332+
// 假设的测试函数调用,你可能需要根据实际情况调整
333+
await instance.chat({
334+
messages: [{ content: 'Hello', role: 'user' }],
335+
model: 'mistralai/mistral-7b-instruct:free',
336+
temperature: 0,
337+
});
338+
339+
// 验证 debugStream 被调用
340+
expect(debugStreamModule.debugStream).toHaveBeenCalled();
341+
342+
// 恢复原始环境变量值
343+
process.env.DEBUG_GROQ_CHAT_COMPLETION = originalDebugValue;
344+
});
345+
});
346+
});
347+
});

0 commit comments

Comments
 (0)