Skip to content

Commit c9ab910

Browse files
authored
providers + prep (#99)
1 parent c7aec7c commit c9ab910

File tree

13 files changed

+201
-293
lines changed

13 files changed

+201
-293
lines changed

.changeset/selfish-birds-appear.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@instructor-ai/instructor": patch
3+
---
4+
5+
Adding explicit support for non-oai providers - currently anyscale and together ai - will do explicit checks on mode selected vs provider and model

.example.env

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
OPENAI_API_KEY=
2-
ANYSCALE_API_KEY=
2+
ANYSCALE_API_KEY=
3+
TOGETHER_API_KEY=

.github/workflows/test.yml

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ jobs:
1414
env:
1515
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
1616
ANYSCALE_API_KEY: ${{ secrets.ANYSCALE_API_KEY }}
17+
TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
1718

1819
steps:
1920
- uses: actions/checkout@v3

bun.lockb

-64 Bytes
Binary file not shown.

package.json

+1-2
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@
5151
},
5252
"homepage": "https://github.com/instructor-ai/instructor-js#readme",
5353
"dependencies": {
54-
"zod-stream": "^0.0.5",
55-
"zod-to-json-schema": "^3.22.3",
54+
"zod-stream": "0.0.6",
5655
"zod-validation-error": "^2.1.0"
5756
},
5857
"peerDependencies": {

src/constants/modes.ts

-35
This file was deleted.

src/constants/providers.ts

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import { MODE, type Mode } from "zod-stream"
2+
3+
export const PROVIDERS = {
4+
OAI: "OAI",
5+
ANYSCALE: "ANYSCALE",
6+
TOGETHER: "TOGETHER",
7+
OTHER: "OTHER"
8+
} as const
9+
10+
export type Provider = keyof typeof PROVIDERS
11+
12+
export const PROVIDER_SUPPORTED_MODES: {
13+
[key in Provider]: Mode[]
14+
} = {
15+
[PROVIDERS.OTHER]: [MODE.FUNCTIONS, MODE.TOOLS, MODE.JSON, MODE.MD_JSON, MODE.JSON_SCHEMA],
16+
[PROVIDERS.OAI]: [MODE.FUNCTIONS, MODE.TOOLS, MODE.JSON, MODE.MD_JSON],
17+
[PROVIDERS.ANYSCALE]: [MODE.TOOLS, MODE.JSON, MODE.JSON_SCHEMA],
18+
[PROVIDERS.TOGETHER]: [MODE.TOOLS, MODE.JSON, MODE.MD_JSON, MODE.JSON_SCHEMA]
19+
} as const
20+
21+
export const NON_OAI_PROVIDER_URLS = {
22+
[PROVIDERS.ANYSCALE]: "api.endpoints.anyscale",
23+
[PROVIDERS.TOGETHER]: "api.together.xyz",
24+
[PROVIDERS.OAI]: "api.openai.com"
25+
} as const
26+
27+
export const PROVIDER_SUPPORTED_MODES_BY_MODEL = {
28+
[PROVIDERS.OTHER]: {
29+
[MODE.FUNCTIONS]: ["*"],
30+
[MODE.TOOLS]: ["*"],
31+
[MODE.JSON]: ["*"],
32+
[MODE.MD_JSON]: ["*"],
33+
[MODE.JSON_SCHEMA]: ["*"]
34+
},
35+
[PROVIDERS.OAI]: {
36+
[MODE.FUNCTIONS]: ["*"],
37+
[MODE.TOOLS]: ["*"],
38+
[MODE.JSON]: [
39+
"gpt-3.5-turbo-1106",
40+
"gpt-4-1106-preview",
41+
"gpt-4-0125-preview",
42+
"gpt-4-turbo-preview"
43+
],
44+
[MODE.MD_JSON]: ["*"]
45+
},
46+
[PROVIDERS.TOGETHER]: {
47+
[MODE.JSON_SCHEMA]: [
48+
"mistralai/Mixtral-8x7B-Instruct-v0.1",
49+
"mistralai/Mistral-7B-Instruct-v0.1",
50+
"togethercomputer/CodeLlama-34b-Instruct"
51+
],
52+
[MODE.MD_JSON]: ["*"],
53+
[MODE.TOOLS]: [
54+
"mistralai/Mixtral-8x7B-Instruct-v0.1",
55+
"mistralai/Mistral-7B-Instruct-v0.1",
56+
"togethercomputer/CodeLlama-34b-Instruct"
57+
]
58+
},
59+
[PROVIDERS.ANYSCALE]: {
60+
[MODE.JSON_SCHEMA]: [
61+
"mistralai/Mistral-7B-Instruct-v0.1",
62+
"mistralai/Mixtral-8x7B-Instruct-v0.1"
63+
],
64+
[MODE.MD_JSON]: ["*"],
65+
[MODE.TOOLS]: ["mistralai/Mistral-7B-Instruct-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1"]
66+
}
67+
}

src/dsl/validator.ts

+2-4
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@ import { OAIClientExtended } from "@/instructor"
22
import type { ChatCompletionCreateParams } from "openai/resources/chat/completions.mjs"
33
import { RefinementCtx, z } from "zod"
44

5-
// eslint-disable-next-line @typescript-eslint/no-explicit-any
6-
type AsyncSuperRefineFunction = (data: any, ctx: RefinementCtx) => Promise<any>
5+
type AsyncSuperRefineFunction = (data: string, ctx: RefinementCtx) => Promise<void>
76

87
export const LLMValidator = (
98
instructor: OAIClientExtended,
@@ -15,7 +14,7 @@ export const LLMValidator = (
1514
reason: z.string().optional()
1615
})
1716

18-
const fn = async (value, ctx) => {
17+
return async (value, ctx) => {
1918
const validated = await instructor.chat.completions.create({
2019
max_retries: 0,
2120
...params,
@@ -41,5 +40,4 @@ export const LLMValidator = (
4140
})
4241
}
4342
}
44-
return fn
4543
}

src/instructor.ts

+47-10
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,27 @@ import {
22
ChatCompletionCreateParamsWithModel,
33
InstructorConfig,
44
LogLevel,
5-
Mode,
65
ReturnTypeBasedOnParams
76
} from "@/types"
87
import OpenAI from "openai"
98
import { z } from "zod"
10-
import ZodStream, { OAIStream, withResponseModel } from "zod-stream"
9+
import ZodStream, { OAIResponseParser, OAIStream, withResponseModel, type Mode } from "zod-stream"
1110
import { fromZodError } from "zod-validation-error"
1211

13-
import { MODE, MODE_TO_PARSER } from "@/constants/modes"
12+
import {
13+
NON_OAI_PROVIDER_URLS,
14+
Provider,
15+
PROVIDER_SUPPORTED_MODES,
16+
PROVIDER_SUPPORTED_MODES_BY_MODEL,
17+
PROVIDERS
18+
} from "./constants/providers"
1419

1520
const MAX_RETRIES_DEFAULT = 0
1621

1722
class Instructor {
1823
readonly client: OpenAI
1924
readonly mode: Mode
25+
readonly provider: Provider
2026
readonly debug: boolean = false
2127

2228
/**
@@ -29,11 +35,39 @@ class Instructor {
2935
this.mode = mode
3036
this.debug = debug
3137

32-
//TODO: probably some more sophisticated validation we can do here re: modes and otherwise.
33-
// but just throwing quick here for now.
34-
if (mode === MODE.JSON_SCHEMA) {
35-
if (!this.client.baseURL.includes("anyscale")) {
36-
throw new Error("JSON_SCHEMA mode is only support on Anyscale.")
38+
const provider =
39+
this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.ANYSCALE) ? PROVIDERS.ANYSCALE
40+
: this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.TOGETHER) ? PROVIDERS.TOGETHER
41+
: this.client?.baseURL.includes(NON_OAI_PROVIDER_URLS.OAI) ? PROVIDERS.OAI
42+
: PROVIDERS.OTHER
43+
44+
this.provider = provider
45+
46+
this.validateOptions()
47+
}
48+
49+
private validateOptions() {
50+
const isModeSupported = PROVIDER_SUPPORTED_MODES[this.provider].includes(this.mode)
51+
52+
if (this.provider === PROVIDERS.OTHER) {
53+
this.log("debug", "Unknown provider - can't validate options.")
54+
}
55+
56+
if (!isModeSupported) {
57+
throw new Error(`Mode ${this.mode} is not supported by provider ${this.provider}`)
58+
}
59+
}
60+
61+
private validateModelModeSupport<T extends z.AnyZodObject>(
62+
params: ChatCompletionCreateParamsWithModel<T>
63+
) {
64+
if (this.provider !== PROVIDERS.OAI) {
65+
const modelSupport = PROVIDER_SUPPORTED_MODES_BY_MODEL[this.provider][this.mode]
66+
67+
if (!modelSupport.includes("*") && !modelSupport.includes(params.model)) {
68+
throw new Error(
69+
`Model ${params.model} is not supported by provider ${this.provider} in mode ${this.mode}`
70+
)
3771
}
3872
}
3973
}
@@ -98,9 +132,10 @@ class Instructor {
98132
this.log("debug", response_model.name, "making completion call with params: ", resolvedParams)
99133

100134
const completion = await this.client.chat.completions.create(resolvedParams)
101-
const parser = MODE_TO_PARSER[this.mode]
102135

103-
const parsedCompletion = parser(completion as OpenAI.Chat.Completions.ChatCompletion)
136+
const parsedCompletion = OAIResponseParser(
137+
completion as OpenAI.Chat.Completions.ChatCompletion
138+
)
104139
try {
105140
return JSON.parse(parsedCompletion) as z.infer<T>
106141
} catch (error) {
@@ -200,6 +235,8 @@ class Instructor {
200235
>(
201236
params: P
202237
): Promise<ReturnTypeBasedOnParams<P>> => {
238+
this.validateModelModeSupport(params)
239+
203240
if (this.isChatCompletionCreateParamsWithModel(params)) {
204241
if (params.stream) {
205242
return this.chatCompletionStream(params) as ReturnTypeBasedOnParams<

src/oai/params.ts

-100
This file was deleted.

0 commit comments

Comments
 (0)