Skip to content

Commit f65672c

Browse files
authored
adding provider specific transformers + better logging + fixing anyscale (#132)
1 parent 076aaa6 commit f65672c

File tree

7 files changed

+131
-29
lines changed

7 files changed

+131
-29
lines changed

.changeset/stupid-ducks-act.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@instructor-ai/instructor": minor
3+
---
4+
5+
adding meta to standard completions as well and including usage - also added more verbose debug logs and new provider-specific transformers to handle discrepancies in various APIs

src/constants/providers.ts

+50-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
import { MODE, type Mode } from "zod-stream"
1+
import { omit } from "@/lib"
2+
import OpenAI from "openai"
3+
import { z } from "zod"
4+
import { MODE, withResponseModel, type Mode } from "zod-stream"
25

36
export const PROVIDERS = {
47
OAI: "OAI",
@@ -24,6 +27,52 @@ export const NON_OAI_PROVIDER_URLS = {
2427
[PROVIDERS.OAI]: "api.openai.com"
2528
} as const
2629

30+
export const PROVIDER_PARAMS_TRANSFORMERS = {
31+
[PROVIDERS.ANYSCALE]: {
32+
[MODE.JSON_SCHEMA]: function removeAdditionalPropertiesKeyJSONSchema<
33+
T extends z.AnyZodObject,
34+
P extends OpenAI.ChatCompletionCreateParams
35+
>(params: ReturnType<typeof withResponseModel<T, "JSON_SCHEMA", P>>) {
36+
if ("additionalProperties" in params.response_format.schema) {
37+
return {
38+
...params,
39+
response_format: {
40+
...params.response_format,
41+
schema: omit(["additionalProperties"], params.response_format.schema)
42+
}
43+
}
44+
}
45+
46+
return params
47+
},
48+
[MODE.TOOLS]: function removeAdditionalPropertiesKeyTools<
49+
T extends z.AnyZodObject,
50+
P extends OpenAI.ChatCompletionCreateParams
51+
>(params: ReturnType<typeof withResponseModel<T, "TOOLS", P>>) {
52+
if (params.tools.some(tool => tool.function?.parameters)) {
53+
return {
54+
...params,
55+
tools: params.tools.map(tool => {
56+
if (tool.function?.parameters) {
57+
return {
58+
...tool,
59+
function: {
60+
...tool.function,
61+
parameters: omit(["additionalProperties"], tool.function.parameters)
62+
}
63+
}
64+
}
65+
66+
return tool
67+
})
68+
}
69+
}
70+
71+
return params
72+
}
73+
}
74+
} as const
75+
2776
export const PROVIDER_SUPPORTED_MODES_BY_MODEL = {
2877
[PROVIDERS.OTHER]: {
2978
[MODE.FUNCTIONS]: ["*"],

src/index.ts

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
import Instructor from "./instructor"
22

3+
export * from "./types"
34
export default Instructor

src/instructor.ts

+58-17
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,18 @@ import {
66
} from "@/types"
77
import OpenAI from "openai"
88
import { z } from "zod"
9-
import ZodStream, {
10-
CompletionMeta,
11-
OAIResponseParser,
12-
OAIStream,
13-
withResponseModel,
14-
type Mode
15-
} from "zod-stream"
9+
import ZodStream, { OAIResponseParser, OAIStream, withResponseModel, type Mode } from "zod-stream"
1610
import { fromZodError } from "zod-validation-error"
1711

1812
import {
1913
NON_OAI_PROVIDER_URLS,
2014
Provider,
15+
PROVIDER_PARAMS_TRANSFORMERS,
2116
PROVIDER_SUPPORTED_MODES,
2217
PROVIDER_SUPPORTED_MODES_BY_MODEL,
2318
PROVIDERS
2419
} from "./constants/providers"
20+
import { CompletionMeta } from "./types"
2521

2622
const MAX_RETRIES_DEFAULT = 0
2723

@@ -109,7 +105,9 @@ class Instructor {
109105
let validationIssues = ""
110106
let lastMessage: OpenAI.ChatCompletionMessageParam | null = null
111107

112-
const completionParams = withResponseModel({
108+
const paramsTransformer = PROVIDER_PARAMS_TRANSFORMERS?.[this.provider]?.[this.mode]
109+
110+
let completionParams = withResponseModel({
113111
params: {
114112
...params,
115113
stream: false
@@ -118,6 +116,10 @@ class Instructor {
118116
response_model
119117
})
120118

119+
if (!!paramsTransformer) {
120+
completionParams = paramsTransformer(completionParams)
121+
}
122+
121123
const makeCompletionCall = async () => {
122124
let resolvedParams = completionParams
123125

@@ -135,17 +137,33 @@ class Instructor {
135137
}
136138
}
137139

138-
this.log("debug", response_model.name, "making completion call with params: ", resolvedParams)
140+
let completion: OpenAI.Chat.Completions.ChatCompletion | null = null
139141

140-
const completion = await this.client.chat.completions.create(resolvedParams)
142+
try {
143+
completion = await this.client.chat.completions.create(resolvedParams)
144+
this.log("debug", "raw standard completion response: ", completion)
145+
} catch (error) {
146+
this.log(
147+
"error",
148+
`Error making completion call - mode: ${this.mode} | Client base URL: ${this.client.baseURL} | with params:`,
149+
resolvedParams,
150+
`raw error`,
151+
error
152+
)
153+
154+
throw error
155+
}
141156

142157
const parsedCompletion = OAIResponseParser(
143158
completion as OpenAI.Chat.Completions.ChatCompletion
144159
)
160+
145161
try {
146-
return JSON.parse(parsedCompletion) as z.infer<T>
162+
const data = JSON.parse(parsedCompletion) as z.infer<T> & { _meta?: CompletionMeta }
163+
return { ...data, _meta: { usage: completion?.usage ?? undefined } }
147164
} catch (error) {
148165
this.log("error", "failed to parse completion", parsedCompletion, this.mode)
166+
throw error
149167
}
150168
}
151169

@@ -173,13 +191,29 @@ class Instructor {
173191
return validation.data
174192
} catch (error) {
175193
if (attempts < max_retries) {
176-
this.log("debug", response_model.name, "Retrying, attempt: ", attempts)
177-
this.log("warn", response_model.name, "Validation error: ", validationIssues)
194+
this.log(
195+
"debug",
196+
`response model: ${response_model.name} - Retrying, attempt: `,
197+
attempts
198+
)
199+
this.log(
200+
"warn",
201+
`response model: ${response_model.name} - Validation issues: `,
202+
validationIssues
203+
)
178204
attempts++
179205
return await makeCompletionCallWithRetries()
180206
} else {
181-
this.log("debug", response_model.name, "Max attempts reached: ", attempts)
182-
this.log("error", response_model.name, "Error: ", validationIssues)
207+
this.log(
208+
"debug",
209+
`response model: ${response_model.name} - Max attempts reached: ${attempts}`
210+
)
211+
this.log(
212+
"error",
213+
`response model: ${response_model.name} - Validation issues: `,
214+
validationIssues
215+
)
216+
183217
throw error
184218
}
185219
}
@@ -193,13 +227,15 @@ class Instructor {
193227
response_model,
194228
...params
195229
}: ChatCompletionCreateParamsWithModel<T>): Promise<
196-
AsyncGenerator<Partial<T> & { _meta: CompletionMeta }, void, unknown>
230+
AsyncGenerator<Partial<T> & { _meta?: CompletionMeta }, void, unknown>
197231
> {
198232
if (max_retries) {
199233
this.log("warn", "max_retries is not supported for streaming completions")
200234
}
201235

202-
const completionParams = withResponseModel({
236+
const paramsTransformer = PROVIDER_PARAMS_TRANSFORMERS?.[this.provider]?.[this.mode]
237+
238+
let completionParams = withResponseModel({
203239
params: {
204240
...params,
205241
stream: true
@@ -208,13 +244,18 @@ class Instructor {
208244
mode: this.mode
209245
})
210246

247+
if (paramsTransformer) {
248+
completionParams = paramsTransformer(completionParams)
249+
}
250+
211251
const streamClient = new ZodStream({
212252
debug: this.debug ?? false
213253
})
214254

215255
return streamClient.create({
216256
completionPromise: async () => {
217257
const completion = await this.client.chat.completions.create(completionParams)
258+
this.log("debug", "raw stream completion response: ", completion)
218259

219260
return OAIStream({
220261
res: completion

src/types/index.ts

+7-4
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ import OpenAI from "openai"
22
import { Stream } from "openai/streaming"
33
import { z } from "zod"
44
import {
5-
CompletionMeta,
5+
CompletionMeta as ZCompletionMeta,
66
type Mode as ZMode,
77
type ResponseModel as ZResponseModel
88
} from "zod-stream"
99

1010
export type LogLevel = "debug" | "info" | "warn" | "error"
11-
11+
export type CompletionMeta = Partial<ZCompletionMeta> & {
12+
usage?: OpenAI.CompletionUsage
13+
}
1214
export type Mode = ZMode
1315
export type ResponseModel<T extends z.AnyZodObject> = ZResponseModel<T>
1416

@@ -37,7 +39,8 @@ export type ReturnTypeBasedOnParams<P> =
3739
response_model: ResponseModel<infer T>
3840
}
3941
) ?
40-
Promise<AsyncGenerator<Partial<z.infer<T>> & { _meta: CompletionMeta }, void, unknown>>
41-
: P extends { response_model: ResponseModel<infer T> } ? Promise<z.infer<T>>
42+
Promise<AsyncGenerator<Partial<z.infer<T>> & { _meta?: CompletionMeta }, void, unknown>>
43+
: P extends { response_model: ResponseModel<infer T> } ?
44+
Promise<z.infer<T> & { _meta?: CompletionMeta }>
4245
: P extends { stream: true } ? Stream<OpenAI.Chat.Completions.ChatCompletionChunk>
4346
: OpenAI.Chat.Completions.ChatCompletion

tests/inference.test.ts

+9-5
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
// 6. response_model, no stream, max_retries
88

99
import Instructor from "@/instructor"
10+
import { type CompletionMeta } from "@/types"
1011
import { describe, expect, test } from "bun:test"
1112
import OpenAI from "openai"
1213
import { Stream } from "openai/streaming"
1314
import { type } from "ts-inference-check"
1415
import { z } from "zod"
15-
import { CompletionMeta } from "zod-stream"
1616

1717
describe("Inference Checking", () => {
1818
const UserSchema = z.object({
@@ -61,7 +61,9 @@ describe("Inference Checking", () => {
6161
stream: false
6262
})
6363

64-
expect(type(user).strictly.is<z.infer<typeof UserSchema>>(true)).toBe(true)
64+
expect(
65+
type(user).strictly.is<z.infer<typeof UserSchema> & { _meta?: CompletionMeta }>(true)
66+
).toBe(true)
6567
})
6668

6769
test("response_model, stream", async () => {
@@ -79,7 +81,7 @@ describe("Inference Checking", () => {
7981
Partial<{
8082
name: string
8183
age: number
82-
}> & { _meta: CompletionMeta },
84+
}> & { _meta?: CompletionMeta },
8385
void,
8486
unknown
8587
>
@@ -103,7 +105,7 @@ describe("Inference Checking", () => {
103105
Partial<{
104106
name: string
105107
age: number
106-
}> & { _meta: CompletionMeta },
108+
}> & { _meta?: CompletionMeta },
107109
void,
108110
unknown
109111
>
@@ -120,6 +122,8 @@ describe("Inference Checking", () => {
120122
max_retries: 3
121123
})
122124

123-
expect(type(user).strictly.is<z.infer<typeof UserSchema>>(true)).toBe(true)
125+
expect(
126+
type(user).strictly.is<z.infer<typeof UserSchema> & { _meta?: CompletionMeta }>(true)
127+
).toBe(true)
124128
})
125129
})

tests/mode.test.ts

+1-2
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,7 @@ async function extractUser(model: string, mode: Mode, provider: Provider) {
9696
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
9797
model: model,
9898
response_model: { schema: UserSchema, name: "User" },
99-
max_retries: 4,
100-
seed: provider === PROVIDERS.OAI ? 1 : undefined
99+
max_retries: 4
101100
})
102101

103102
return user

0 commit comments

Comments
 (0)