Skip to content

Commit ab971ec

Browse files
add sagemaker provider (#1917)
Co-authored-by: Patrick Erichsen <[email protected]>
1 parent e43a250 commit ab971ec

File tree

18 files changed

+344
-0
lines changed

18 files changed

+344
-0
lines changed

binary/package-lock.json

+2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

core/config/types.ts

+1
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,7 @@ declare global {
507507
| "gemini"
508508
| "mistral"
509509
| "bedrock"
510+
| "sagemaker"
510511
| "deepinfra"
511512
| "flowise"
512513
| "groq"

core/control-plane/schema.ts

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ const modelDescriptionSchema = z.object({
1313
"gemini",
1414
"mistral",
1515
"bedrock",
16+
"sagemaker",
1617
"cloudflare",
1718
"azure",
1819
]),

core/index.d.ts

+1
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,7 @@ type ModelProvider =
593593
| "gemini"
594594
| "mistral"
595595
| "bedrock"
596+
| "sagemaker"
596597
| "deepinfra"
597598
| "flowise"
598599
| "groq"

core/llm/autodetect.ts

+3
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ const PROVIDER_HANDLES_TEMPLATING: ModelProvider[] = [
4242
"msty",
4343
"anthropic",
4444
"bedrock",
45+
"sagemaker",
4546
"continue-proxy",
4647
"mistral",
4748
];
@@ -54,6 +55,7 @@ const PROVIDER_SUPPORTS_IMAGES: ModelProvider[] = [
5455
"msty",
5556
"anthropic",
5657
"bedrock",
58+
"sagemaker",
5759
"continue-proxy",
5860
];
5961

@@ -97,6 +99,7 @@ function modelSupportsImages(
9799
const PARALLEL_PROVIDERS: ModelProvider[] = [
98100
"anthropic",
99101
"bedrock",
102+
"sagemaker",
100103
"deepinfra",
101104
"gemini",
102105
"huggingface-inference-api",

core/llm/llms/SageMaker.ts

+210
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
import {
2+
SageMakerRuntimeClient,
3+
InvokeEndpointWithResponseStreamCommand
4+
} from "@aws-sdk/client-sagemaker-runtime";
5+
import { fromIni } from "@aws-sdk/credential-providers";
6+
7+
const jinja = require("jinja-js");
8+
9+
import {
10+
ChatMessage,
11+
CompletionOptions,
12+
LLMOptions,
13+
MessageContent,
14+
ModelProvider,
15+
} from "../../index.js";
16+
import { BaseLLM } from "../index.js";
17+
18+
class SageMaker extends BaseLLM {
19+
private static PROFILE_NAME: string = "sagemaker";
20+
static providerName: ModelProvider = "sagemaker";
21+
static defaultOptions: Partial<LLMOptions> = {
22+
region: "us-west-2",
23+
contextLength: 200_000,
24+
};
25+
26+
constructor(options: LLMOptions) {
27+
super(options);
28+
if (!options.apiBase) {
29+
this.apiBase = `https://runtime.sagemaker.${options.region}.amazonaws.com`;
30+
}
31+
}
32+
33+
protected async *_streamComplete(
34+
prompt: string,
35+
options: CompletionOptions,
36+
): AsyncGenerator<string> {
37+
const credentials = await this._getCredentials();
38+
const client = new SageMakerRuntimeClient({
39+
region: this.region,
40+
credentials: {
41+
accessKeyId: credentials.accessKeyId,
42+
secretAccessKey: credentials.secretAccessKey,
43+
sessionToken: credentials.sessionToken || "",
44+
},
45+
});
46+
const toolkit = new CompletionAPIToolkit(this);
47+
const command = toolkit.generateCommand([], prompt, options);
48+
const response = await client.send(command);
49+
if (response.Body) {
50+
for await (const value of response.Body) {
51+
const text = toolkit.unwrapResponseChunk(value);
52+
if (text) {
53+
yield text;
54+
}
55+
}
56+
}
57+
}
58+
59+
protected async *_streamChat(
60+
messages: ChatMessage[],
61+
options: CompletionOptions,
62+
): AsyncGenerator<ChatMessage> {
63+
const credentials = await this._getCredentials();
64+
const client = new SageMakerRuntimeClient({
65+
region: this.region,
66+
credentials: {
67+
accessKeyId: credentials.accessKeyId,
68+
secretAccessKey: credentials.secretAccessKey,
69+
sessionToken: credentials.sessionToken || "",
70+
},
71+
});
72+
const toolkit = new MessageAPIToolkit(this);
73+
74+
const command = toolkit.generateCommand(messages, "", options);
75+
const response = await client.send(command);
76+
if (response.Body) {
77+
for await (const value of response.Body) {
78+
const text = toolkit.unwrapResponseChunk(value);
79+
if (text) {
80+
yield { role: "assistant", content: text };
81+
}
82+
}
83+
}
84+
}
85+
86+
private async _getCredentials() {
87+
try {
88+
return await fromIni({
89+
profile: SageMaker.PROFILE_NAME,
90+
})();
91+
} catch (e) {
92+
console.warn(
93+
`AWS profile with name ${SageMaker.PROFILE_NAME} not found in ~/.aws/credentials, using default profile`,
94+
);
95+
return await fromIni()();
96+
}
97+
}
98+
99+
}
100+
101+
/**
 * Strategy interface for talking to a SageMaker endpoint: builds the
 * streaming invocation command and decodes each streamed response chunk.
 */
interface SageMakerModelToolkit {
  /**
   * Builds the streaming invoke command for either a chat request
   * (`messages`) or a raw-prompt request (`prompt`); implementations use
   * one of the two. `options.model` names the SageMaker endpoint.
   */
  generateCommand(
    messages: ChatMessage[],
    prompt: string,
    options: CompletionOptions,
  ): InvokeEndpointWithResponseStreamCommand;
  /** Extracts the text from one event-stream chunk; "" when nothing usable. */
  unwrapResponseChunk(rawValue: any): string;
}
109+
110+
class MessageAPIToolkit implements SageMakerModelToolkit {
111+
constructor(private sagemaker: SageMaker) {}
112+
generateCommand(
113+
messages: ChatMessage[],
114+
prompt: string,
115+
options: CompletionOptions,
116+
): InvokeEndpointWithResponseStreamCommand {
117+
118+
if ("chat_template" in this.sagemaker.completionOptions) {
119+
// for some model you can apply chat_template to the model
120+
let prompt = jinja.compile(this.sagemaker.completionOptions.chat_template).render(
121+
{messages: messages, add_generation_prompt: true},
122+
{autoEscape: false}
123+
)
124+
const payload = {
125+
inputs: prompt,
126+
parameters: this.sagemaker.completionOptions,
127+
stream: true,
128+
};
129+
130+
return new InvokeEndpointWithResponseStreamCommand({
131+
EndpointName: options.model,
132+
Body: new TextEncoder().encode(JSON.stringify(payload)),
133+
ContentType: "application/json",
134+
CustomAttributes: "accept_eula=false",
135+
});
136+
}
137+
else {
138+
const payload = {
139+
messages: messages,
140+
max_tokens: options.maxTokens,
141+
temperature: options.temperature,
142+
top_p: options.topP,
143+
stream: "true",
144+
};
145+
146+
return new InvokeEndpointWithResponseStreamCommand({
147+
EndpointName: options.model,
148+
Body: new TextEncoder().encode(JSON.stringify(payload)),
149+
ContentType: "application/json",
150+
CustomAttributes: "accept_eula=false",
151+
});
152+
}
153+
154+
}
155+
unwrapResponseChunk(rawValue: any): string {
156+
const binaryChunk = rawValue.PayloadPart?.Bytes;
157+
const textChunk = new TextDecoder().decode(binaryChunk);
158+
try {
159+
const chunk = JSON.parse(textChunk)
160+
if ("choices" in chunk) {
161+
return chunk.choices[0].delta.content;
162+
}
163+
else if ("token" in chunk) {
164+
return chunk.token.text;
165+
}
166+
else {
167+
return "";
168+
}
169+
} catch (error) {
170+
console.error(textChunk);
171+
console.error(error);
172+
return "";
173+
}
174+
}
175+
}
176+
class CompletionAPIToolkit implements SageMakerModelToolkit {
177+
constructor(private sagemaker: SageMaker) {}
178+
generateCommand(
179+
messages: ChatMessage[],
180+
prompt: string,
181+
options: CompletionOptions,
182+
): InvokeEndpointWithResponseStreamCommand {
183+
const payload = {
184+
inputs: prompt,
185+
parameters: this.sagemaker.completionOptions,
186+
stream: true,
187+
};
188+
189+
return new InvokeEndpointWithResponseStreamCommand({
190+
EndpointName: options.model,
191+
Body: new TextEncoder().encode(JSON.stringify(payload)),
192+
ContentType: "application/json",
193+
CustomAttributes: "accept_eula=false",
194+
});
195+
}
196+
unwrapResponseChunk(rawValue: any): string {
197+
const binaryChunk = rawValue.PayloadPart?.Bytes;
198+
const textChunk = new TextDecoder().decode(binaryChunk);
199+
try {
200+
return JSON.parse(textChunk).token.text;
201+
} catch (error) {
202+
console.error(textChunk);
203+
console.error(error);
204+
return "";
205+
}
206+
}
207+
}
208+
209+
210+
export default SageMaker;

core/llm/llms/index.ts

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import TextGenWebUI from "./TextGenWebUI";
3333
import Together from "./Together";
3434
import ContinueProxy from "./stubs/ContinueProxy";
3535
import WatsonX from "./WatsonX";
36+
import SageMaker from "./SageMaker";
3637
import { renderTemplatedString } from "../../promptFiles/renderTemplatedString";
3738

3839
const LLMs = [
@@ -52,6 +53,7 @@ const LLMs = [
5253
LMStudio,
5354
Mistral,
5455
Bedrock,
56+
SageMaker,
5557
DeepInfra,
5658
Flowise,
5759
Groq,

core/package-lock.json

+10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

core/package.json

+2
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,15 @@
4040
"dependencies": {
4141
"@aws-sdk/client-bedrock-runtime": "^3.620.1",
4242
"@aws-sdk/credential-providers": "^3.620.1",
43+
"@aws-sdk/client-sagemaker-runtime": "^3.621.0",
4344
"@continuedev/config-types": "^1.0.10",
4445
"@continuedev/llm-info": "^1.0.1",
4546
"@mozilla/readability": "^0.5.0",
4647
"@octokit/rest": "^20.0.2",
4748
"@typescript-eslint/eslint-plugin": "^7.8.0",
4849
"@typescript-eslint/parser": "^7.8.0",
4950
"@xenova/transformers": "2.14.0",
51+
"jinja-js": "0.1.8",
5052
"adf-to-md": "^1.1.0",
5153
"async-mutex": "^0.5.0",
5254
"axios": "^1.6.7",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# AWS SageMaker
2+
3+
The SageMaker provider supports SageMaker endpoints deployed with [LMI](https://docs.djl.ai/docs/serving/serving/docs/lmi/index.html).
4+
5+
To setup SageMaker, add the following to your `config.json` file:
6+
7+
```json title="~/.continue/config.json"
8+
{
9+
"models": [
10+
{
11+
"title": "deepseek-6.7b-instruct",
12+
"provider": "sagemaker",
13+
"model": "lmi-model-deepseek-coder-xxxxxxx",
14+
"region": "us-west-2"
15+
}
16+
]
17+
}
18+
```
19+
20+
The value of `model` should be the name of the SageMaker endpoint you deployed.
21+
22+
Authentication will be through temporary or long-term credentials in
23+
~/.aws/credentials under a profile called "sagemaker".
24+
25+
```title="~/.aws/credentials"
26+
[sagemaker]
27+
aws_access_key_id = abcdefg
28+
aws_secret_access_key = hijklmno
29+
aws_session_token = pqrstuvwxyz # Optional: means short term creds.
30+
```

docs/docs/setup/model-providers.md

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ keywords:
1818
Ollama,
1919
HuggingFace,
2020
AWS Bedrock,
21+
AWS SageMaker,
2122
]
2223
---
2324

0 commit comments

Comments
 (0)