
Commit dc22633

Force Name Parameter (#66)
1 parent 1f67aea commit dc22633

18 files changed: +77 -59 lines changed

.changeset/plenty-dots-visit.md (+5)

@@ -0,0 +1,5 @@
+---
+"@instructor-ai/instructor": patch
+---
+
+Cleanup Types, make response_model.name required and rely on inference
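
With this change every response_model must include a name alongside its schema, and the result type is inferred from that schema instead of being annotated by the caller. A minimal sketch of the updated call shape, pieced together from the examples and tests in this commit (the default Instructor import and the client options are assumptions based on the package's existing setup, not shown in this diff):

import Instructor from "@instructor-ai/instructor"
import OpenAI from "openai"
import { z } from "zod"

const UserSchema = z.object({
  age: z.number(),
  name: z.string()
})

const oai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY })
const client = Instructor({ client: oai, mode: "TOOLS" })

// `name` is now required; `user` is inferred as z.infer<typeof UserSchema>.
const user = await client.chat.completions.create({
  messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
  model: "gpt-3.5-turbo",
  response_model: { schema: UserSchema, name: "User" },
  max_retries: 3,
  seed: 1
})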

examples/action_items/index.ts (+1 -1)

@@ -46,7 +46,7 @@ const extractActionItems = async (data: string) => {
       }
     ],
     model: "gpt-4-1106-preview",
-    response_model: { schema: ActionItemsSchema },
+    response_model: { schema: ActionItemsSchema, name: "ActionItems" },
     max_tokens: 1000,
     temperature: 0.0,
     max_retries: 2,

examples/classification/multi_prediction/index.ts (+1 -1)

@@ -27,7 +27,7 @@ const createClassification = async (data: string) => {
   const classification = await client.chat.completions.create({
     messages: [{ role: "user", content: `"Classify the following support ticket: ${data}` }],
     model: "gpt-3.5-turbo",
-    response_model: { schema: MultiClassificationSchema },
+    response_model: { schema: MultiClassificationSchema, name: "MultiClassification" },
     max_retries: 3,
     seed: 1
   })

examples/classification/simple_prediction/index.ts (+1 -1)

@@ -26,7 +26,7 @@ const createClassification = async (data: string) => {
   const classification = await client.chat.completions.create({
     messages: [{ role: "user", content: `"Classify the following text: ${data}` }],
     model: "gpt-3.5-turbo",
-    response_model: { schema: SimpleClassificationSchema },
+    response_model: { schema: SimpleClassificationSchema, name: "SimpleClassification" },
     max_retries: 3,
     seed: 1
   })

examples/extract_user/anyscale.ts (+1 -1)

@@ -28,7 +28,7 @@ const client = Instructor({
 const user = await client.chat.completions.create({
   messages: [{ role: "user", content: "Harry Potter" }],
   model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
-  response_model: { schema: UserSchema },
+  response_model: { schema: UserSchema, name: "User" },
   max_retries: 3
 })

examples/extract_user/index.ts (+2 -1)

@@ -21,7 +21,8 @@ const user = await client.chat.completions.create({
   messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
   model: "gpt-4",
   response_model: {
-    schema: UserSchema
+    schema: UserSchema,
+    name: "User"
   },
   max_retries: 3,
   seed: 1

examples/extract_user/properties.ts (+1 -1)

@@ -28,7 +28,7 @@ const client = Instructor({
 const user = await client.chat.completions.create({
   messages: [{ role: "user", content: "Happy Potter" }],
   model: "gpt-4",
-  response_model: { schema: UserSchema },
+  response_model: { schema: UserSchema, name: "User" },
   max_retries: 3,
   seed: 1
 })

examples/knowledge-graph/index.ts (+1 -1)

@@ -39,7 +39,7 @@ const createGraph = async (input: string) => {
       }
     ],
     model: "gpt-3.5-turbo-1106",
-    response_model: { schema: KnowledgeGraphSchema },
+    response_model: { schema: KnowledgeGraphSchema, name: "Knowledge Graph" },
     max_retries: 5,
     seed: 1
   })

examples/passthrough/index.ts (-1)

@@ -15,7 +15,6 @@ const client = Instructor({
 const completion = (await client.chat.completions.create({
   messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
   model: "gpt-3.5-turbo",
-  max_retries: 3,
   seed: 1
 })) satisfies OpenAI.Chat.ChatCompletion

examples/query_decomposition/index.ts (+1 -1)

@@ -39,7 +39,7 @@ const createQueryPlan = async (question: string) => {
       }
     ],
     model: "gpt-4-1106-preview",
-    response_model: { schema: QueryPlanSchema },
+    response_model: { schema: QueryPlanSchema, name: "Query Plan Decomposition" },
     max_tokens: 1000,
     temperature: 0.0,
     max_retries: 2,

examples/resolving-complex-entitities/index.ts (+1 -1)

@@ -60,7 +60,7 @@ const askAi = async (input: string) => {
       }
     ],
     model: "gpt-4",
-    response_model: { schema: DocumentExtractionSchema },
+    response_model: { schema: DocumentExtractionSchema, name: "Document Extraction" },
     max_retries: 3,
     seed: 1
   })

src/instructor.ts (+21 -22)

@@ -38,19 +38,19 @@ const MODE_TO_PARAMS = {
   [MODE.JSON_SCHEMA]: OAIBuildMessageBasedParams
 }
 
-type ResponseModel<T> = {
+type ResponseModel<T extends z.ZodTypeAny> = {
   schema: T
-  name?: string
+  name: string
   description?: string
 }
 
-type InstructorChatCompletionParams<T> = {
+type InstructorChatCompletionParams<T extends z.ZodTypeAny> = {
   response_model: ResponseModel<T>
   max_retries?: number
 }
 
-type ChatCompletionCreateParamsWithModel<T extends z.ZodTypeAny> = ChatCompletionCreateParams &
-  InstructorChatCompletionParams<T>
+export type ChatCompletionCreateParamsWithModel<T extends z.ZodTypeAny> =
+  InstructorChatCompletionParams<T> & ChatCompletionCreateParams
 
 type ReturnTypeBasedOnParams<P> = P extends ChatCompletionCreateParamsWithModel<infer T>
   ? P extends { stream: true }
@@ -96,17 +96,17 @@ class Instructor {
    * @param {ChatCompletionCreateParamsWithModel} params - The parameters for chat completion.
    * @returns {Promise<any>} The response from the chat completion.
    */
-  chatCompletion = async <T extends z.ZodTypeAny>({
+  chatCompletion = <T extends z.ZodTypeAny>({
     max_retries = 3,
     ...params
-  }: ChatCompletionCreateParamsWithModel<T>): Promise<
-    Promise<z.infer<T>> | AsyncGenerator<z.infer<T>, void, unknown>
+  }: ChatCompletionCreateParamsWithModel<T>): ReturnTypeBasedOnParams<
+    ChatCompletionCreateParamsWithModel<T>
   > => {
     let attempts = 0
     let validationIssues = ""
     let lastMessage: ChatCompletionMessageParam | null = null
 
-    const completionParams = this.buildChatCompletionParams({ ...params })
+    const completionParams = this.buildChatCompletionParams(params)
 
     const makeCompletionCall = async () => {
       let resolvedParams = completionParams
@@ -133,7 +133,7 @@ class Instructor {
 
       if ("choices" in completion) {
         const parsedCompletion = parser(completion)
-        return JSON.parse(parsedCompletion)
+        return JSON.parse(parsedCompletion) as z.infer<T>
       } else {
         return OAIStream({ res: completion, parser })
       }
@@ -182,7 +182,7 @@ class Instructor {
       }
     }
 
-    return await makeCompletionCallWithRetries()
+    return makeCompletionCallWithRetries()
   }
 
   private async partialStreamResponse({ stream, schema }) {
@@ -243,10 +243,9 @@ class Instructor {
    * @returns {ChatCompletionCreateParams} The chat completion parameters.
    */
   private buildChatCompletionParams = <T extends z.ZodTypeAny>({
-    response_model,
+    response_model: { name, schema, description },
     ...params
   }: ChatCompletionCreateParamsWithModel<T>): ChatCompletionCreateParams => {
-    const { schema, name = "response_model", description } = response_model
     const safeName = name.replace(/[^a-zA-Z0-9]/g, "_").replace(/\s/g, "_")
 
     const { definitions } = zodToJsonSchema(schema, {
@@ -275,7 +274,7 @@ class Instructor {
     }
   }
 
-  chatCompletionWithoutModel = async (
+  chatCompletionWithoutModel = (
     params: ChatCompletionCreateParams
   ): Promise<
     Stream<OpenAI.Chat.Completions.ChatCompletionChunk> | OpenAI.Chat.Completions.ChatCompletion
@@ -286,18 +285,18 @@ class Instructor {
   public chat = {
     completions: {
       create: <
-        P extends ChatCompletionCreateParamsWithModel<z.ZodTypeAny> | ChatCompletionCreateParams
+        T extends z.ZodTypeAny | undefined,
+        P extends T extends z.ZodTypeAny
+          ? ChatCompletionCreateParamsWithModel<T>
+          : ChatCompletionCreateParams & { response_model: never }
       >(
         params: P
      ): ReturnTypeBasedOnParams<P> => {
-        if ("response_model" in params && params.response_model?.schema !== undefined) {
-          return this.chatCompletion(
-            params as ChatCompletionCreateParamsWithModel<z.ZodTypeAny>
-          ) as ReturnTypeBasedOnParams<P>
+        if ("response_model" in params) {
+          console.log(params.response_model.name)
+          return this.chatCompletion(params) as ReturnTypeBasedOnParams<P>
        } else {
-          return this.chatCompletionWithoutModel(
-            params as ChatCompletionCreateParams
-          ) as ReturnTypeBasedOnParams<P>
+          return this.chatCompletionWithoutModel(params) as ReturnTypeBasedOnParams<P>
        }
      }
    }
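
A note on the required name: several examples now pass human-readable names with spaces ("Knowledge Graph", "Query Plan Decomposition"), and buildChatCompletionParams sanitizes the value before using it as the OpenAI function name. A small sketch of that step in isolation, reusing the regexes exactly as they appear in the hunk above (the helper name toSafeName exists only in this sketch; the second replace is effectively redundant because the first already rewrites whitespace):

const toSafeName = (name: string): string =>
  name.replace(/[^a-zA-Z0-9]/g, "_").replace(/\s/g, "_")

console.log(toSafeName("Query Plan Decomposition")) // "Query_Plan_Decomposition"
console.log(toSafeName("Knowledge Graph")) // "Knowledge_Graph"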

src/oai/params.ts (+28 -8)

@@ -1,8 +1,20 @@
+import { ChatCompletionCreateParamsWithModel } from "@/instructor"
 import { omit } from "@/lib"
+import { ChatCompletionCreateParams } from "openai/resources/index.mjs"
+import { z } from "zod"
+import { JsonSchema7Type } from "zod-to-json-schema"
 
 import { MODE } from "@/constants/modes"
 
-export function OAIBuildFunctionParams(definition, params) {
+type ParseParams = {
+  name: string
+  description?: string
+} & JsonSchema7Type
+
+export function OAIBuildFunctionParams<T extends z.ZodTypeAny>(
+  definition: ParseParams,
+  params: Omit<ChatCompletionCreateParamsWithModel<T>, "response_model">
+): ChatCompletionCreateParams {
   const { name, description, ...definitionParams } = definition
 
   return {
@@ -21,7 +33,10 @@ export function OAIBuildFunctionParams(definition, params) {
   }
 }
 
-export function OAIBuildToolFunctionParams(definition, params) {
+export function OAIBuildToolFunctionParams<T extends z.ZodTypeAny>(
+  definition: ParseParams,
+  params: Omit<ChatCompletionCreateParamsWithModel<T>, "response_model">
+): ChatCompletionCreateParams {
   const { name, description, ...definitionParams } = definition
 
   return {
@@ -35,7 +50,7 @@ export function OAIBuildToolFunctionParams(definition, params) {
       type: "function",
       function: {
         name: name,
-        description: description ?? undefined,
+        description: description,
         parameters: definitionParams
       }
     },
@@ -44,7 +59,11 @@ export function OAIBuildToolFunctionParams(definition, params) {
   }
 }
 
-export function OAIBuildMessageBasedParams(definition, params, mode) {
+export function OAIBuildMessageBasedParams<T extends z.ZodTypeAny>(
+  definition: ParseParams,
+  params: Omit<ChatCompletionCreateParamsWithModel<T>, "response_model">,
+  mode: MODE // This type should be typeof MODE.JSON | typeof MODE.JSON_SCHEMA | typeof MODE.MD_JSON
+): ChatCompletionCreateParams {
   const MODE_SPECIFIC_CONFIGS = {
     [MODE.JSON]: {
       response_format: { type: "json_object" }
@@ -57,24 +76,25 @@ export function OAIBuildMessageBasedParams(definition, params, mode) {
     }
   }
 
-  const modeConfig = MODE_SPECIFIC_CONFIGS[mode] ?? {}
+  const modeConfig = MODE_SPECIFIC_CONFIGS[mode]
 
-  return {
+  const t = {
     ...params,
     ...modeConfig,
     messages: [
-      ...(params?.messages ?? []),
+      ...params.messages,
      {
        role: "system",
        content: `
        Given a user prompt, you will return fully valid JSON based on the following description and schema.
        You will return no other prose. You will take into account any descriptions or required parameters within the schema
        and return a valid and fully escaped JSON object that matches the schema and those instructions.
 
-        description: ${definition?.description}
+        description: ${definition.description}
        json schema: ${JSON.stringify(definition)}
        `
      }
    ]
  }
+  return t
 }
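
The new ParseParams type spells out what these builders receive: the required name, an optional description, and the JSON-schema fields generated from the Zod schema. A rough, hypothetical illustration of assembling such a definition object by hand (the real code builds it inside buildChatCompletionParams, and the exact zodToJsonSchema options it uses are not visible in this diff):

import { z } from "zod"
import { zodToJsonSchema, JsonSchema7Type } from "zod-to-json-schema"

type ParseParams = {
  name: string
  description?: string
} & JsonSchema7Type

const UserSchema = z.object({ age: z.number(), name: z.string() })

// Hypothetical definition in the ParseParams shape: a sanitized name plus
// the JSON-schema output for the Zod schema.
const definition: ParseParams = {
  name: "User",
  description: "Extract a user record",
  ...(zodToJsonSchema(UserSchema) as JsonSchema7Type)
}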

tests/extract.test.ts (+3 -3)

@@ -22,7 +22,7 @@ async function extractUser() {
   const user = await client.chat.completions.create({
     messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
     model: "gpt-3.5-turbo",
-    response_model: { schema: UserSchema },
+    response_model: { schema: UserSchema, name: "User" },
     seed: 1
   })
 
@@ -50,7 +50,7 @@ async function extractUserValidated() {
   const user = await client.chat.completions.create({
     messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
     model: "gpt-4",
-    response_model: { schema: UserSchema },
+    response_model: { schema: UserSchema, name: "User" },
     max_retries: 3,
     seed: 1
   })
@@ -83,7 +83,7 @@ async function extractUserMany() {
   const user = await client.chat.completions.create({
     messages: [{ role: "user", content: "Jason is 30 years old, Sarah is 12" }],
     model: "gpt-3.5-turbo",
-    response_model: { schema: UsersSchema },
+    response_model: { schema: UsersSchema, name: "Users" },
     max_retries: 3,
     seed: 1
   })

tests/functions.test.ts (+6 -12)

@@ -9,8 +9,6 @@ async function extractUser() {
     name: z.string()
   })
 
-  type User = z.infer<typeof UserSchema>
-
   const oai = new OpenAI({
     apiKey: process.env.OPENAI_API_KEY ?? undefined,
     organization: process.env.OPENAI_ORG_ID ?? undefined
@@ -21,10 +19,10 @@ async function extractUser() {
     mode: "FUNCTIONS"
   })
 
-  const user: User = await client.chat.completions.create({
+  const user = await client.chat.completions.create({
     messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
     model: "gpt-3.5-turbo",
-    response_model: { schema: UserSchema },
+    response_model: { schema: UserSchema, name: "User" },
     seed: 1
   })
 
@@ -42,8 +40,6 @@ async function extractUserValidated() {
       .describe("The users name, all uppercase")
   })
 
-  type User = z.infer<typeof UserSchema>
-
   const oai = new OpenAI({
     apiKey: process.env.OPENAI_API_KEY ?? undefined,
     organization: process.env.OPENAI_ORG_ID ?? undefined
@@ -54,10 +50,10 @@ async function extractUserValidated() {
     mode: "TOOLS"
   })
 
-  const user: User = await client.chat.completions.create({
+  const user = await client.chat.completions.create({
     messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
     model: "gpt-3.5-turbo",
-    response_model: { schema: UserSchema },
+    response_model: { schema: UserSchema, name: "User" },
     max_retries: 3,
     seed: 1
   })
@@ -77,8 +73,6 @@ async function extractUserMany() {
     })
     .describe("Correctly formatted list of users")
 
-  type Users = z.infer<typeof UsersSchema>
-
   const oai = new OpenAI({
     apiKey: process.env.OPENAI_API_KEY ?? undefined,
     organization: process.env.OPENAI_ORG_ID ?? undefined
@@ -89,10 +83,10 @@ async function extractUserMany() {
     mode: "TOOLS"
   })
 
-  const user: Users = await client.chat.completions.create({
+  const user = await client.chat.completions.create({
     messages: [{ role: "user", content: "Jason is 30 years old, Sarah is 12" }],
     model: "gpt-3.5-turbo",
-    response_model: { schema: UsersSchema },
+    response_model: { schema: UsersSchema, name: "Users" },
     max_retries: 3,
     seed: 1
   })
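
The dropped `type User = z.infer<typeof UserSchema>` aliases and the explicit `const user: User = ...` annotations are the "rely on inference" half of this change: with the stricter generics in src/instructor.ts, the result of create is inferred from the schema carried in response_model. A simplified, types-only sketch of that inference, assuming the conditional-type shape introduced above (the real ReturnTypeBasedOnParams also has a streaming branch that the diff truncates):

import { z } from "zod"

// Params that carry a response_model with a schema.
type WithModel<T> = { response_model: { schema: T; name: string } }

// If P carries a Zod schema, the result type is z.infer of that schema.
type ResultOf<P> = P extends WithModel<infer T>
  ? T extends z.ZodTypeAny
    ? z.infer<T>
    : never
  : unknown

const UserSchema = z.object({ age: z.number(), name: z.string() })

// Resolves to { age: number; name: string } without any manual annotation.
type InferredUser = ResultOf<{ response_model: { schema: typeof UserSchema; name: "User" } }>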

tests/maybe.test.ts (+1 -1)

@@ -25,7 +25,7 @@ async function maybeExtractUser(content: string) {
   const user = await client.chat.completions.create({
     messages: [{ role: "user", content: "Extract " + content }],
     model: "gpt-4",
-    response_model: { schema: MaybeUserSchema },
+    response_model: { schema: MaybeUserSchema, name: "User" },
     max_retries: 3,
     seed: 1
   })
