1
1
import {
2
2
ChatCompletionCreateParamsWithModel ,
3
+ GenericChatCompletion ,
3
4
InstructorConfig ,
4
5
LogLevel ,
5
- ReturnTypeBasedOnParams
6
+ OpenAILikeClient ,
7
+ ReturnTypeBasedOnParams ,
8
+ SupportedInstructorClient
6
9
} from "@/types"
7
10
import OpenAI from "openai"
8
11
import { z } from "zod"
@@ -17,22 +20,22 @@ import {
17
20
PROVIDER_SUPPORTED_MODES_BY_MODEL ,
18
21
PROVIDERS
19
22
} from "./constants/providers"
20
- import { CompletionMeta } from "./types"
23
+ import { ClientTypeChatCompletionParams , CompletionMeta } from "./types"
21
24
22
25
const MAX_RETRIES_DEFAULT = 0
23
26
24
- class Instructor {
25
- readonly client : OpenAI
27
+ class Instructor < C extends SupportedInstructorClient > {
28
+ readonly client : OpenAILikeClient < C >
26
29
readonly mode : Mode
27
30
readonly provider : Provider
28
31
readonly debug : boolean = false
29
32
30
33
/**
31
34
* Creates an instance of the `Instructor` class.
32
- * @param {OpenAI } client - The OpenAI client.
35
+ * @param {OpenAILikeClient } client - An OpenAI-like client.
33
36
* @param {string } mode - The mode of operation.
34
37
*/
35
- constructor ( { client, mode, debug = false } : InstructorConfig ) {
38
+ constructor ( { client, mode, debug = false } : InstructorConfig < C > ) {
36
39
this . client = client
37
40
this . mode = mode
38
41
this . debug = debug
@@ -41,6 +44,7 @@ class Instructor {
41
44
this . client ?. baseURL . includes ( NON_OAI_PROVIDER_URLS . ANYSCALE ) ? PROVIDERS . ANYSCALE
42
45
: this . client ?. baseURL . includes ( NON_OAI_PROVIDER_URLS . TOGETHER ) ? PROVIDERS . TOGETHER
43
46
: this . client ?. baseURL . includes ( NON_OAI_PROVIDER_URLS . OAI ) ? PROVIDERS . OAI
47
+ : this . client ?. baseURL . includes ( NON_OAI_PROVIDER_URLS . ANTHROPIC ) ? PROVIDERS . ANTHROPIC
44
48
: PROVIDERS . OTHER
45
49
46
50
this . provider = provider
@@ -137,10 +141,12 @@ class Instructor {
137
141
}
138
142
}
139
143
140
- let completion : OpenAI . Chat . Completions . ChatCompletion | null = null
144
+ let completion : GenericChatCompletion | null = null
141
145
142
146
try {
143
- completion = await this . client . chat . completions . create ( resolvedParams )
147
+ completion = ( await this . client . chat . completions . create (
148
+ resolvedParams
149
+ ) ) as GenericChatCompletion
144
150
this . log ( "debug" , "raw standard completion response: " , completion )
145
151
} catch ( error ) {
146
152
this . log (
@@ -258,7 +264,8 @@ class Instructor {
258
264
this . log ( "debug" , "raw stream completion response: " , completion )
259
265
260
266
return OAIStream ( {
261
- res : completion
267
+ //TODO: we need to move away from strict openai types - need to cast here but should update to be more flexible
268
+ res : completion as AsyncIterable < OpenAI . ChatCompletionChunk >
262
269
} )
263
270
} ,
264
271
response_model
@@ -282,41 +289,46 @@ class Instructor {
282
289
create : async <
283
290
T extends z . AnyZodObject ,
284
291
P extends T extends z . AnyZodObject ? ChatCompletionCreateParamsWithModel < T >
285
- : OpenAI . ChatCompletionCreateParams & { response_model : never }
292
+ : ClientTypeChatCompletionParams < typeof this . client > & { response_model : never }
286
293
> (
287
294
params : P
288
- ) : Promise < ReturnTypeBasedOnParams < P > > => {
295
+ ) : Promise < ReturnTypeBasedOnParams < typeof this . client , P > > => {
289
296
this . validateModelModeSupport ( params )
290
297
291
298
if ( this . isChatCompletionCreateParamsWithModel ( params ) ) {
292
299
if ( params . stream ) {
293
300
return this . chatCompletionStream ( params ) as ReturnTypeBasedOnParams <
301
+ typeof this . client ,
294
302
P & { stream : true }
295
303
>
296
304
} else {
297
- return this . chatCompletionStandard ( params ) as ReturnTypeBasedOnParams < P >
305
+ return this . chatCompletionStandard ( params ) as ReturnTypeBasedOnParams <
306
+ typeof this . client ,
307
+ P
308
+ >
298
309
}
299
310
} else {
300
- const result : OpenAI . Chat . Completions . ChatCompletion =
311
+ const result =
301
312
this . isStandardStream ( params ) ?
302
313
await this . client . chat . completions . create ( params )
303
314
: await this . client . chat . completions . create ( params )
304
315
305
- return result as ReturnTypeBasedOnParams < P >
316
+ return result as ReturnTypeBasedOnParams < typeof this . client , P >
306
317
}
307
318
}
308
319
}
309
320
}
310
321
}
311
322
312
- export type OAIClientExtended = OpenAI & Instructor
323
+ export type InstructorClient < C extends SupportedInstructorClient = OpenAI > = Instructor < C > &
324
+ OpenAILikeClient < C >
313
325
314
326
/**
315
327
* Creates an instance of the `Instructor` class.
316
- * @param {OpenAI } client - The OpenAI client.
328
+ * @param {OpenAILikeClient } client - The OpenAI client.
317
329
* @param {string } mode - The mode of operation.
318
330
* @param {boolean } debug - Whether to log debug messages.
319
- * @returns {OAIClientExtended } The extended OpenAI client.
331
+ * @returns {InstructorClient } The extended OpenAI client.
320
332
*
321
333
* @example
322
334
* import createInstructor from "@instructor-ai/instructor"
@@ -326,24 +338,26 @@ export type OAIClientExtended = OpenAI & Instructor
326
338
*
327
339
* const client = createInstructor({
328
340
* client: OAI,
329
- * mode: "TOOLS",
341
+ * mode: "TOOLS",
330
342
* })
331
343
*
332
344
* @param args
333
345
* @returns
334
346
*/
335
- export default function ( args : { client : OpenAI ; mode : Mode ; debug ?: boolean } ) : OAIClientExtended {
336
- const instructor = new Instructor ( args )
337
-
347
+ export default function < C extends SupportedInstructorClient = OpenAI > ( args : {
348
+ client : OpenAILikeClient < C >
349
+ mode : Mode
350
+ debug ?: boolean
351
+ } ) : InstructorClient < C > {
352
+ const instructor = new Instructor < C > ( args )
338
353
const instructorWithProxy = new Proxy ( instructor , {
339
354
get : ( target , prop , receiver ) => {
340
355
if ( prop in target ) {
341
356
return Reflect . get ( target , prop , receiver )
342
357
}
343
-
344
358
return Reflect . get ( target . client , prop , receiver )
345
359
}
346
360
} )
347
361
348
- return instructorWithProxy as OAIClientExtended
362
+ return instructorWithProxy as InstructorClient < C >
349
363
}
0 commit comments