Skip to content

Commit 0a5bbd8

Browse files
authored
Non-oai usage meta + non-oai client types (#182)
1 parent f386ad7 commit 0a5bbd8

File tree

8 files changed

+123
-27
lines changed

8 files changed

+123
-27
lines changed

.changeset/light-chefs-clean.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@instructor-ai/instructor": minor
3+
---
4+
5+
update client types to better support non oai clients + updates to allow for passing usage properties into meta from non-oai clients

bun.lockb

597 Bytes
Binary file not shown.

package.json

+3-3
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
"zod": ">=3.22.4"
6060
},
6161
"devDependencies": {
62-
"@anthropic-ai/sdk": "latest",
62+
"@anthropic-ai/sdk": "0.22.0",
6363
"@changesets/changelog-github": "^0.5.0",
6464
"@changesets/cli": "^2.27.1",
6565
"@ianvs/prettier-plugin-sort-imports": "4.1.0",
@@ -75,8 +75,8 @@
7575
"eslint-plugin-only-warn": "^1.1.0",
7676
"eslint-plugin-prettier": "^5.1.2",
7777
"husky": "^8.0.3",
78-
"llm-polyglot": "1.0.0",
79-
"openai": "latest",
78+
"llm-polyglot": "2.0.0",
79+
"openai": "4.50.0",
8080
"prettier": "latest",
8181
"ts-inference-check": "^0.3.0",
8282
"tsup": "^8.0.1",

src/instructor.ts

+50-7
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ import {
2222
PROVIDER_SUPPORTED_MODES_BY_MODEL,
2323
PROVIDERS
2424
} from "./constants/providers"
25+
import { iterableTee } from "./lib"
2526
import { ClientTypeChatCompletionParams, CompletionMeta } from "./types"
2627

2728
const MAX_RETRIES_DEFAULT = 0
2829

29-
class Instructor<C extends GenericClient | OpenAI> {
30+
class Instructor<C> {
3031
readonly client: OpenAILikeClient<C>
3132
readonly mode: Mode
3233
readonly provider: Provider
@@ -46,7 +47,17 @@ class Instructor<C extends GenericClient | OpenAI> {
4647
logger = undefined,
4748
retryAllErrors = false
4849
}: InstructorConfig<C>) {
49-
this.client = client
50+
if (!isGenericClient(client) && !(client instanceof OpenAI)) {
51+
throw new Error("Client does not match the required structure")
52+
}
53+
54+
if (client instanceof OpenAI) {
55+
this.client = client as OpenAI
56+
} else {
57+
this.client = client as C & GenericClient
58+
}
59+
60+
// this.client = client
5061
this.mode = mode
5162
this.debug = debug
5263
this.retryAllErrors = retryAllErrors
@@ -308,7 +319,9 @@ class Instructor<C extends GenericClient | OpenAI> {
308319
debug: this.debug ?? false
309320
})
310321

311-
async function checkForUsage(reader: Stream<OpenAI.ChatCompletionChunk>) {
322+
async function checkForUsage(
323+
reader: Stream<OpenAI.ChatCompletionChunk> | AsyncIterable<OpenAI.ChatCompletionChunk>
324+
) {
312325
for await (const chunk of reader) {
313326
if ("usage" in chunk) {
314327
streamUsage = chunk.usage as CompletionMeta["usage"]
@@ -345,6 +358,24 @@ class Instructor<C extends GenericClient | OpenAI> {
345358
})
346359
}
347360

361+
//check if async iterator
362+
if (
363+
this.provider !== "OAI" &&
364+
completionParams?.stream &&
365+
completion?.[Symbol.asyncIterator]
366+
) {
367+
const [completion1, completion2] = await iterableTee(
368+
completion as AsyncIterable<OpenAI.ChatCompletionChunk>,
369+
2
370+
)
371+
372+
checkForUsage(completion1)
373+
374+
return OAIStream({
375+
res: completion2
376+
})
377+
}
378+
348379
return OAIStream({
349380
res: completion as unknown as AsyncIterable<OpenAI.ChatCompletionChunk>
350381
})
@@ -419,7 +450,7 @@ class Instructor<C extends GenericClient | OpenAI> {
419450
}
420451
}
421452

422-
export type InstructorClient<C extends GenericClient | OpenAI> = Instructor<C> & OpenAILikeClient<C>
453+
export type InstructorClient<C> = Instructor<C> & OpenAILikeClient<C>
423454

424455
/**
425456
* Creates an instance of the `Instructor` class.
@@ -442,9 +473,7 @@ export type InstructorClient<C extends GenericClient | OpenAI> = Instructor<C> &
442473
* @param args
443474
* @returns
444475
*/
445-
export default function createInstructor<C extends GenericClient | OpenAI>(
446-
args: InstructorConfig<C>
447-
): InstructorClient<C> {
476+
export default function createInstructor<C>(args: InstructorConfig<C>): InstructorClient<C> {
448477
const instructor = new Instructor<C>(args)
449478
const instructorWithProxy = new Proxy(instructor, {
450479
get: (target, prop, receiver) => {
@@ -458,3 +487,17 @@ export default function createInstructor<C extends GenericClient | OpenAI>(
458487

459488
return instructorWithProxy as InstructorClient<C>
460489
}
490+
//eslint-disable-next-line @typescript-eslint/no-explicit-any
491+
function isGenericClient(client: any): client is GenericClient {
492+
return (
493+
typeof client === "object" &&
494+
client !== null &&
495+
"baseURL" in client &&
496+
"chat" in client &&
497+
typeof client.chat === "object" &&
498+
"completions" in client.chat &&
499+
typeof client.chat.completions === "object" &&
500+
"create" in client.chat.completions &&
501+
typeof client.chat.completions.create === "function"
502+
)
503+
}

src/lib/index.ts

+42
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,45 @@ export function omit<T extends object, K extends keyof T>(keys: K[], obj: T): Om
77
}
88
return result
99
}
10+
11+
export async function iterableTee<T>(
12+
iterable: AsyncIterable<T>,
13+
n: number
14+
): Promise<AsyncGenerator<T>[]> {
15+
const buffers: T[][] = Array.from({ length: n }, () => [])
16+
const resolvers: (() => void)[] = []
17+
const iterator = iterable[Symbol.asyncIterator]()
18+
let done = false
19+
20+
async function* reader(index: number) {
21+
while (true) {
22+
if (buffers[index].length > 0) {
23+
yield buffers[index].shift()!
24+
} else if (done) {
25+
break
26+
} else {
27+
await new Promise<void>(resolve => resolvers.push(resolve))
28+
}
29+
}
30+
}
31+
32+
;(async () => {
33+
for await (const item of {
34+
[Symbol.asyncIterator]: () => iterator
35+
}) {
36+
for (const buffer of buffers) {
37+
buffer.push(item)
38+
}
39+
40+
while (resolvers.length > 0) {
41+
resolvers.shift()!()
42+
}
43+
}
44+
done = true
45+
while (resolvers.length > 0) {
46+
resolvers.shift()!()
47+
}
48+
})()
49+
50+
return Array.from({ length: n }, (_, i) => reader(i))
51+
}

src/types/index.ts

+3-3
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ export type GenericClient = {
3939
baseURL?: string
4040
chat?: {
4141
completions?: {
42-
create?: (params: GenericCreateParams) => Promise<unknown>
42+
create?: <P extends GenericCreateParams>(params: P) => Promise<unknown>
4343
}
4444
}
4545
}
@@ -55,7 +55,7 @@ export type ClientType<C> =
5555
: C extends GenericClient ? "generic"
5656
: never
5757

58-
export type OpenAILikeClient<C> = C extends OpenAI ? OpenAI : C & GenericClient
58+
export type OpenAILikeClient<C> = OpenAI | (C & GenericClient)
5959
export type SupportedInstructorClient = GenericClient | OpenAI
6060
export type LogLevel = "debug" | "info" | "warn" | "error"
6161

@@ -68,7 +68,7 @@ export type Mode = ZMode
6868
export type ResponseModel<T extends z.AnyZodObject> = ZResponseModel<T>
6969

7070
export interface InstructorConfig<C> {
71-
client: OpenAILikeClient<C>
71+
client: C
7272
mode: Mode
7373
debug?: boolean
7474
logger?: <T extends unknown[]>(level: LogLevel, ...args: T) => void

tests/anthropic.test.ts

+20-13
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,16 @@ describe("LLMClient Anthropic Provider - mode: TOOLS", () => {
118118
})
119119
})
120120

121-
describe("LLMClient Anthropic Provider - mode: MD_JSON", () => {
121+
describe("LLMClient Anthropic Provider - mode: TOOLS - stream", () => {
122122
const instructor = Instructor({
123123
client: anthropicClient,
124-
mode: "MD_JSON"
124+
mode: "TOOLS"
125125
})
126126

127127
test("basic completion", async () => {
128128
const completion = await instructor.chat.completions.create({
129129
model: "claude-3-sonnet-20240229",
130+
stream: true,
130131
max_tokens: 1000,
131132
messages: [
132133
{
@@ -135,17 +136,24 @@ describe("LLMClient Anthropic Provider - mode: MD_JSON", () => {
135136
}
136137
],
137138
response_model: {
138-
name: "get_name",
139+
name: "extract_name",
139140
schema: z.object({
140141
name: z.string()
141142
})
142143
}
143144
})
144145

145-
expect(omit(["_meta"], completion)).toEqual({ name: "Dimitri Kennedy" })
146+
let final = {}
147+
148+
for await (const result of completion) {
149+
final = result
150+
}
151+
152+
//@ts-expect-error ignore for testing
153+
expect(omit(["_meta"], final)).toEqual({ name: "Dimitri Kennedy" })
146154
})
147155

148-
test("complex schema - streaming", async () => {
156+
test("complex schema", async () => {
149157
const completion = await instructor.chat.completions.create({
150158
model: "claude-3-sonnet-20240229",
151159
max_tokens: 1000,
@@ -173,14 +181,15 @@ describe("LLMClient Anthropic Provider - mode: MD_JSON", () => {
173181
Programming
174182
Leadership
175183
Communication
176-
177-
178184
`
179185
}
180186
],
181187
response_model: {
182188
name: "process_user_data",
183189
schema: z.object({
190+
story: z
191+
.string()
192+
.describe("A long and mostly made up story about the user - minimum 500 words"),
184193
userDetails: z.object({
185194
firstName: z.string(),
186195
lastName: z.string(),
@@ -196,21 +205,19 @@ describe("LLMClient Anthropic Provider - mode: MD_JSON", () => {
196205
years: z.number().optional()
197206
})
198207
),
199-
skills: z.array(z.string()),
200-
summaryOfWorldWarOne: z
201-
.string()
202-
.describe("A detailed summary of World War One and its major events - min 500 words")
208+
skills: z.array(z.string())
203209
})
204210
}
205211
})
206212

207213
let final = {}
214+
208215
for await (const result of completion) {
209216
final = result
210217
}
211218

212-
//@ts-expect-error - lazy
213-
expect(omit(["_meta", "summaryOfWorldWarOne"], final)).toEqual({
219+
//@ts-expect-error ignore for testing
220+
expect(omit(["_meta", "story"], final)).toEqual({
214221
userDetails: {
215222
firstName: "John",
216223
lastName: "Doe",

tests/stream.test.ts

-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ async function extractUser() {
5959
let extraction: Extraction = {}
6060

6161
for await (const result of extractionStream) {
62-
console.log(result)
6362
try {
6463
extraction = result
6564
expect(result).toHaveProperty("users")

0 commit comments

Comments
 (0)