Skip to content

Commit 6942d65

Browse files
authored
allow request option pass through (#164)
1 parent 224a0d1 commit 6942d65

22 files changed

+268
-96
lines changed

.changeset/curly-ants-tie.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@instructor-ai/instructor": minor
3+
---
4+
5+
Adds request-option pass-through to completion calls, improves handling of non-validation errors, and no longer retries when the failure is not specifically a validation error.

docs/concepts/streaming.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ A follow-up meeting is scheduled for January 25th at 3 PM GMT to finalize the ag
6161

6262
const extractionStream = await client.chat.completions.create({
6363
messages: [{ role: "user", content: textBlock }],
64-
model: "gpt-4-1106-preview",
64+
model: "gpt-4-turbo",
6565
response_model: {
6666
schema: ExtractionValuesSchema,
6767
name: "value extraction"

docs/examples/action_items.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ const extractActionItems = async (data: string): Promise<ActionItems | undefined
6666
"content": `Create the action items for the following transcript: ${data}`,
6767
},
6868
],
69-
model: "gpt-4-1106-preview",
69+
model: "gpt-4-turbo",
7070
response_model: { schema: ActionItemsSchema },
7171
max_tokens: 1000,
7272
temperature: 0.0,

docs/examples/query_decomposition.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ const createQueryPlan = async (question: string): Promise<QueryPlan | undefined>
6565
"content": `Consider: ${question}\nGenerate the correct query plan.`,
6666
},
6767
],
68-
model: "gpt-4-1106-preview",
68+
model: "gpt-4-turbo",
6969
response_model: { schema: QueryPlanSchema },
7070
max_tokens: 1000,
7171
temperature: 0.0,

examples/action_items/index.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ const extractActionItems = async (data: string) => {
4545
content: `Create the action items for the following transcript: ${data}`
4646
}
4747
],
48-
model: "gpt-4-1106-preview",
48+
model: "gpt-4-turbo",
4949
response_model: { schema: ActionItemsSchema, name: "ActionItems" },
5050
max_tokens: 1000,
5151
temperature: 0.0,

examples/extract_user_stream/index.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ let extraction = {}
5353

5454
const extractionStream = await client.chat.completions.create({
5555
messages: [{ role: "user", content: textBlock }],
56-
model: "gpt-4-1106-preview",
56+
model: "gpt-4-turbo",
5757
response_model: {
5858
schema: ExtractionValuesSchema,
5959
name: "value extraction"

examples/llm-validator/index.ts

+3-4
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ const openAi = new OpenAI({ apiKey: process.env.OPENAI_API_KEY ?? "" })
77

88
const instructor = Instructor({
99
client: openAi,
10-
mode: "TOOLS",
11-
debug: true
10+
mode: "TOOLS"
1211
})
1312

1413
const statement = "Do not say questionable things"
@@ -17,7 +16,7 @@ const QuestionAnswer = z.object({
1716
question: z.string(),
1817
answer: z.string().superRefine(
1918
LLMValidator(instructor, statement, {
20-
model: "gpt-4"
19+
model: "gpt-4-turbo"
2120
})
2221
)
2322
})
@@ -26,7 +25,7 @@ const question = "What is the meaning of life?"
2625

2726
const check = async (context: string) => {
2827
return await instructor.chat.completions.create({
29-
model: "gpt-3.5-turbo",
28+
model: "gpt-4-turbo",
3029
max_retries: 2,
3130
response_model: { schema: QuestionAnswer, name: "Question and Answer" },
3231
messages: [

examples/query_decomposition/index.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ const createQueryPlan = async (question: string) => {
3838
content: `Consider: ${question}\nGenerate the correct query plan.`
3939
}
4040
],
41-
model: "gpt-4-1106-preview",
41+
model: "gpt-4-turbo",
4242
response_model: { schema: QueryPlanSchema, name: "Query Plan Decomposition" },
4343
max_tokens: 1000,
4444
temperature: 0.0,

package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@instructor-ai/instructor",
3-
"version": "1.1.2",
3+
"version": "1.1.1",
44
"description": "structured outputs for llms",
55
"publishConfig": {
66
"access": "public"

src/constants/providers.ts

+17-9
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import { omit } from "@/lib"
22
import OpenAI from "openai"
33
import { z } from "zod"
4-
import { MODE, withResponseModel, type Mode } from "zod-stream"
4+
import { withResponseModel, MODE as ZMODE, type Mode } from "zod-stream"
55

6+
export const MODE = ZMODE
67
export const PROVIDERS = {
78
OAI: "OAI",
89
ANYSCALE: "ANYSCALE",
@@ -11,7 +12,6 @@ export const PROVIDERS = {
1112
GROQ: "GROQ",
1213
OTHER: "OTHER"
1314
} as const
14-
1515
export type Provider = keyof typeof PROVIDERS
1616

1717
export const PROVIDER_SUPPORTED_MODES: {
@@ -34,6 +34,19 @@ export const NON_OAI_PROVIDER_URLS = {
3434
} as const
3535

3636
export const PROVIDER_PARAMS_TRANSFORMERS = {
37+
[PROVIDERS.GROQ]: {
38+
[MODE.TOOLS]: function groqToolsParamsTransformer<
39+
T extends z.AnyZodObject,
40+
P extends OpenAI.ChatCompletionCreateParams
41+
>(params: ReturnType<typeof withResponseModel<T, "TOOLS", P>>) {
42+
if (params.tools.some(tool => tool) && params.stream) {
43+
console.warn("Streaming may not be supported when using tools in Groq, try MD_JSON instead")
44+
return params
45+
}
46+
47+
return params
48+
}
49+
},
3750
[PROVIDERS.ANYSCALE]: {
3851
[MODE.JSON_SCHEMA]: function removeAdditionalPropertiesKeyJSONSchema<
3952
T extends z.AnyZodObject,
@@ -90,12 +103,7 @@ export const PROVIDER_SUPPORTED_MODES_BY_MODEL = {
90103
[PROVIDERS.OAI]: {
91104
[MODE.FUNCTIONS]: ["*"],
92105
[MODE.TOOLS]: ["*"],
93-
[MODE.JSON]: [
94-
"gpt-3.5-turbo-1106",
95-
"gpt-4-1106-preview",
96-
"gpt-4-0125-preview",
97-
"gpt-4-turbo-preview"
98-
],
106+
[MODE.JSON]: ["gpt-3.5-turbo-1106", "gpt-4-turbo", "gpt-4-0125-preview", "gpt-4-turbo-preview"],
99107
[MODE.MD_JSON]: ["*"]
100108
},
101109
[PROVIDERS.TOGETHER]: {
@@ -124,7 +132,7 @@ export const PROVIDER_SUPPORTED_MODES_BY_MODEL = {
124132
[MODE.TOOLS]: ["*"]
125133
},
126134
[PROVIDERS.GROQ]: {
127-
[MODE.TOOLS]: ["llama2-70b-4096", "mixtral-8x7b-32768", "gemma-7b-it"],
135+
[MODE.TOOLS]: ["mixtral-8x7b-32768", "gemma-7b-it"],
128136
[MODE.MD_JSON]: ["*"]
129137
}
130138
}

src/dsl/validator.ts

+1-7
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,9 @@ export const LLMValidator = <C extends GenericClient | OpenAI>(
4444
}
4545
}
4646

47-
export const moderationValidator = <C extends GenericClient | OpenAI>(
48-
client: InstructorClient<C>
49-
) => {
47+
export const moderationValidator = (client: InstructorClient<OpenAI>) => {
5048
return async (value: string, ctx: z.RefinementCtx) => {
5149
try {
52-
if (!(client instanceof OpenAI)) {
53-
throw new Error("ModerationValidator only supports OpenAI clients")
54-
}
55-
5650
const response = await client.moderations.create({ input: value })
5751
const flaggedResults = response.results.filter(result => result.flagged)
5852

src/instructor.ts

+60-29
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import {
22
ChatCompletionCreateParamsWithModel,
3+
ClientTypeChatCompletionRequestOptions,
34
GenericChatCompletion,
45
GenericClient,
56
InstructorConfig,
@@ -8,7 +9,7 @@ import {
89
ReturnTypeBasedOnParams
910
} from "@/types"
1011
import OpenAI from "openai"
11-
import { z } from "zod"
12+
import { z, ZodError } from "zod"
1213
import ZodStream, { OAIResponseParser, OAIStream, withResponseModel, type Mode } from "zod-stream"
1314
import { fromZodError } from "zod-validation-error"
1415

@@ -102,11 +103,14 @@ class Instructor<C extends GenericClient | OpenAI> {
102103
}
103104
}
104105

105-
private async chatCompletionStandard<T extends z.AnyZodObject>({
106-
max_retries = MAX_RETRIES_DEFAULT,
107-
response_model,
108-
...params
109-
}: ChatCompletionCreateParamsWithModel<T>): Promise<z.infer<T>> {
106+
private async chatCompletionStandard<T extends z.AnyZodObject>(
107+
{
108+
max_retries = MAX_RETRIES_DEFAULT,
109+
response_model,
110+
...params
111+
}: ChatCompletionCreateParamsWithModel<T>,
112+
requestOptions?: ClientTypeChatCompletionRequestOptions<C>
113+
): Promise<z.infer<T>> {
110114
let attempts = 0
111115
let validationIssues = ""
112116
let lastMessage: OpenAI.ChatCompletionMessageParam | null = null
@@ -147,13 +151,17 @@ class Instructor<C extends GenericClient | OpenAI> {
147151

148152
try {
149153
if (this.client.chat?.completions?.create) {
150-
const result = await this.client.chat.completions.create({
151-
...resolvedParams,
152-
stream: false
153-
})
154+
const result = await this.client.chat.completions.create(
155+
{
156+
...resolvedParams,
157+
stream: false
158+
},
159+
requestOptions
160+
)
161+
154162
completion = result as GenericChatCompletion<typeof result>
155163
} else {
156-
throw new Error("Unsupported client type")
164+
throw new Error("Unsupported client type -- no completion method found.")
157165
}
158166
this.log("debug", "raw standard completion response: ", completion)
159167
} catch (error) {
@@ -176,7 +184,17 @@ class Instructor<C extends GenericClient | OpenAI> {
176184
const data = JSON.parse(parsedCompletion) as z.infer<T> & { _meta?: CompletionMeta }
177185
return { ...data, _meta: { usage: completion?.usage ?? undefined } }
178186
} catch (error) {
179-
this.log("error", "failed to parse completion", parsedCompletion, this.mode)
187+
this.log(
188+
"error",
189+
"failed to parse completion",
190+
parsedCompletion,
191+
this.mode,
192+
"attempt: ",
193+
attempts,
194+
"max attempts: ",
195+
max_retries
196+
)
197+
180198
throw error
181199
}
182200
}
@@ -202,26 +220,38 @@ class Instructor<C extends GenericClient | OpenAI> {
202220
throw new Error("Validation failed.")
203221
}
204222
}
223+
205224
return validation.data
206225
} catch (error) {
226+
if (!(error instanceof ZodError)) {
227+
throw error
228+
}
229+
207230
if (attempts < max_retries) {
208231
this.log(
209232
"debug",
210233
`response model: ${response_model.name} - Retrying, attempt: `,
211234
attempts
212235
)
236+
213237
this.log(
214238
"warn",
215239
`response model: ${response_model.name} - Validation issues: `,
216-
validationIssues
240+
validationIssues,
241+
" - Attempt: ",
242+
attempts,
243+
" - Max attempts: ",
244+
max_retries
217245
)
246+
218247
attempts++
219248
return await makeCompletionCallWithRetries()
220249
} else {
221250
this.log(
222251
"debug",
223252
`response model: ${response_model.name} - Max attempts reached: ${attempts}`
224253
)
254+
225255
this.log(
226256
"error",
227257
`response model: ${response_model.name} - Validation issues: `,
@@ -236,13 +266,10 @@ class Instructor<C extends GenericClient | OpenAI> {
236266
return makeCompletionCallWithRetries()
237267
}
238268

239-
private async chatCompletionStream<T extends z.AnyZodObject>({
240-
max_retries,
241-
response_model,
242-
...params
243-
}: ChatCompletionCreateParamsWithModel<T>): Promise<
244-
AsyncGenerator<Partial<T> & { _meta?: CompletionMeta }, void, unknown>
245-
> {
269+
private async chatCompletionStream<T extends z.AnyZodObject>(
270+
{ max_retries, response_model, ...params }: ChatCompletionCreateParamsWithModel<T>,
271+
requestOptions?: ClientTypeChatCompletionRequestOptions<C>
272+
): Promise<AsyncGenerator<Partial<T> & { _meta?: CompletionMeta }, void, unknown>> {
246273
if (max_retries) {
247274
this.log("warn", "max_retries is not supported for streaming completions")
248275
}
@@ -269,10 +296,13 @@ class Instructor<C extends GenericClient | OpenAI> {
269296
return streamClient.create({
270297
completionPromise: async () => {
271298
if (this.client.chat?.completions?.create) {
272-
const completion = await this.client.chat.completions.create({
273-
...completionParams,
274-
stream: true
275-
})
299+
const completion = await this.client.chat.completions.create(
300+
{
301+
...completionParams,
302+
stream: true
303+
},
304+
requestOptions
305+
)
276306

277307
this.log("debug", "raw stream completion response: ", completion)
278308

@@ -306,18 +336,19 @@ class Instructor<C extends GenericClient | OpenAI> {
306336
P extends T extends z.AnyZodObject ? ChatCompletionCreateParamsWithModel<T>
307337
: ClientTypeChatCompletionParams<OpenAILikeClient<C>> & { response_model: never }
308338
>(
309-
params: P
339+
params: P,
340+
requestOptions?: ClientTypeChatCompletionRequestOptions<C>
310341
): Promise<ReturnTypeBasedOnParams<typeof this.client, P>> => {
311342
this.validateModelModeSupport(params)
312343

313344
if (this.isChatCompletionCreateParamsWithModel(params)) {
314345
if (params.stream) {
315-
return this.chatCompletionStream(params) as ReturnTypeBasedOnParams<
346+
return this.chatCompletionStream(params, requestOptions) as ReturnTypeBasedOnParams<
316347
typeof this.client,
317348
P & { stream: true }
318349
>
319350
} else {
320-
return this.chatCompletionStandard(params) as ReturnTypeBasedOnParams<
351+
return this.chatCompletionStandard(params, requestOptions) as ReturnTypeBasedOnParams<
321352
typeof this.client,
322353
P
323354
>
@@ -326,8 +357,8 @@ class Instructor<C extends GenericClient | OpenAI> {
326357
if (this.client.chat?.completions?.create) {
327358
const result =
328359
this.isStandardStream(params) ?
329-
await this.client.chat.completions.create(params)
330-
: await this.client.chat.completions.create(params)
360+
await this.client.chat.completions.create(params, requestOptions)
361+
: await this.client.chat.completions.create(params, requestOptions)
331362

332363
return result as unknown as ReturnTypeBasedOnParams<OpenAILikeClient<C>, P>
333364
} else {

src/types/index.ts

+7-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ export type GenericCreateParams<M = unknown> = Omit<
1818
[key: string]: unknown
1919
}
2020

21+
export type GenericRequestOptions = Partial<OpenAI.RequestOptions> & {
22+
[key: string]: unknown
23+
}
24+
2125
export type GenericChatCompletion<T = unknown> = Partial<OpenAI.Chat.Completions.ChatCompletion> & {
2226
[key: string]: unknown
2327
choices?: T
@@ -43,15 +47,16 @@ export type GenericClient = {
4347
export type ClientTypeChatCompletionParams<C> =
4448
C extends OpenAI ? OpenAI.ChatCompletionCreateParams : GenericCreateParams
4549

50+
export type ClientTypeChatCompletionRequestOptions<C> =
51+
C extends OpenAI ? OpenAI.RequestOptions : GenericRequestOptions
52+
4653
export type ClientType<C> =
4754
C extends OpenAI ? "openai"
4855
: C extends GenericClient ? "generic"
4956
: never
5057

5158
export type OpenAILikeClient<C> = C extends OpenAI ? OpenAI : C & GenericClient
52-
5359
export type SupportedInstructorClient = GenericClient | OpenAI
54-
5560
export type LogLevel = "debug" | "info" | "warn" | "error"
5661

5762
export type CompletionMeta = Partial<ZCompletionMeta> & {

0 commit comments

Comments (0)