add sagemaker provider #1917

Merged (2 commits) on Aug 7, 2024
2 changes: 2 additions & 0 deletions binary/package-lock.json

1 change: 1 addition & 0 deletions core/config/types.ts
@@ -507,6 +507,7 @@ declare global {
| "gemini"
| "mistral"
| "bedrock"
| "sagemaker"
| "deepinfra"
| "flowise"
| "groq"
1 change: 1 addition & 0 deletions core/control-plane/schema.ts
@@ -13,6 +13,7 @@ const modelDescriptionSchema = z.object({
"gemini",
"mistral",
"bedrock",
"sagemaker",
"cloudflare",
"azure",
]),
1 change: 1 addition & 0 deletions core/index.d.ts
@@ -592,6 +592,7 @@ type ModelProvider =
| "gemini"
| "mistral"
| "bedrock"
| "sagemaker"
| "deepinfra"
| "flowise"
| "groq"
3 changes: 3 additions & 0 deletions core/llm/autodetect.ts
@@ -42,6 +42,7 @@ const PROVIDER_HANDLES_TEMPLATING: ModelProvider[] = [
"msty",
"anthropic",
"bedrock",
"sagemaker",
"continue-proxy",
"mistral",
];
@@ -54,6 +55,7 @@ const PROVIDER_SUPPORTS_IMAGES: ModelProvider[] = [
"msty",
"anthropic",
"bedrock",
"sagemaker",
"continue-proxy",
];

@@ -97,6 +99,7 @@ function modelSupportsImages(
const PARALLEL_PROVIDERS: ModelProvider[] = [
"anthropic",
"bedrock",
"sagemaker",
"deepinfra",
"gemini",
"huggingface-inference-api",
210 changes: 210 additions & 0 deletions core/llm/llms/SageMaker.ts
@@ -0,0 +1,210 @@
import {
SageMakerRuntimeClient,
InvokeEndpointWithResponseStreamCommand
} from "@aws-sdk/client-sagemaker-runtime";
import { fromIni } from "@aws-sdk/credential-providers";

const jinja = require("jinja-js");

import {
ChatMessage,
CompletionOptions,
LLMOptions,
MessageContent,
ModelProvider,
} from "../../index.js";
import { BaseLLM } from "../index.js";

class SageMaker extends BaseLLM {
private static PROFILE_NAME: string = "sagemaker";
static providerName: ModelProvider = "sagemaker";
static defaultOptions: Partial<LLMOptions> = {
region: "us-west-2",
contextLength: 200_000,
};

constructor(options: LLMOptions) {
super(options);
if (!options.apiBase) {
this.apiBase = `https://runtime.sagemaker.${options.region}.amazonaws.com`;
}
}

protected async *_streamComplete(
prompt: string,
options: CompletionOptions,
): AsyncGenerator<string> {
const credentials = await this._getCredentials();
const client = new SageMakerRuntimeClient({
region: this.region,
credentials: {
accessKeyId: credentials.accessKeyId,
secretAccessKey: credentials.secretAccessKey,
sessionToken: credentials.sessionToken || "",
},
});
const toolkit = new CompletionAPIToolkit(this);
const command = toolkit.generateCommand([], prompt, options);
const response = await client.send(command);
if (response.Body) {
for await (const value of response.Body) {
const text = toolkit.unwrapResponseChunk(value);
if (text) {
yield text;
}
}
}
}

protected async *_streamChat(
messages: ChatMessage[],
options: CompletionOptions,
): AsyncGenerator<ChatMessage> {
const credentials = await this._getCredentials();
const client = new SageMakerRuntimeClient({
region: this.region,
credentials: {
accessKeyId: credentials.accessKeyId,
secretAccessKey: credentials.secretAccessKey,
sessionToken: credentials.sessionToken || "",
},
});
const toolkit = new MessageAPIToolkit(this);

const command = toolkit.generateCommand(messages, "", options);
const response = await client.send(command);
if (response.Body) {
for await (const value of response.Body) {
const text = toolkit.unwrapResponseChunk(value);
if (text) {
yield { role: "assistant", content: text };
}
}
}
}

private async _getCredentials() {
try {
return await fromIni({
profile: SageMaker.PROFILE_NAME,
})();
} catch (e) {
console.warn(
`AWS profile with name ${SageMaker.PROFILE_NAME} not found in ~/.aws/credentials, using default profile`,
);
return await fromIni()();
}
}

}

interface SageMakerModelToolkit {
generateCommand(
messages: ChatMessage[],
prompt: string,
options: CompletionOptions,
): InvokeEndpointWithResponseStreamCommand;
unwrapResponseChunk(rawValue: any): string;
}

class MessageAPIToolkit implements SageMakerModelToolkit {
constructor(private sagemaker: SageMaker) {}
generateCommand(
messages: ChatMessage[],
prompt: string,
options: CompletionOptions,
): InvokeEndpointWithResponseStreamCommand {

if ("chat_template" in this.sagemaker.completionOptions) {
// Some models provide a chat_template that can be rendered over the chat messages
const renderedPrompt = jinja.compile(this.sagemaker.completionOptions.chat_template).render(
{messages: messages, add_generation_prompt: true},
{autoEscape: false},
);
const payload = {
inputs: renderedPrompt,
parameters: this.sagemaker.completionOptions,
stream: true,
};

return new InvokeEndpointWithResponseStreamCommand({
EndpointName: options.model,
Body: new TextEncoder().encode(JSON.stringify(payload)),
ContentType: "application/json",
CustomAttributes: "accept_eula=false",
});
}
else {
const payload = {
messages: messages,
max_tokens: options.maxTokens,
temperature: options.temperature,
top_p: options.topP,
stream: "true",
};

return new InvokeEndpointWithResponseStreamCommand({
EndpointName: options.model,
Body: new TextEncoder().encode(JSON.stringify(payload)),
ContentType: "application/json",
CustomAttributes: "accept_eula=false",
});
}

}
unwrapResponseChunk(rawValue: any): string {
const binaryChunk = rawValue.PayloadPart?.Bytes;
const textChunk = new TextDecoder().decode(binaryChunk);
try {
const chunk = JSON.parse(textChunk);
if ("choices" in chunk) {
return chunk.choices[0].delta.content;
}
else if ("token" in chunk) {
return chunk.token.text;
}
else {
return "";
}
} catch (error) {
console.error(textChunk);
console.error(error);
return "";
}
}
}
class CompletionAPIToolkit implements SageMakerModelToolkit {
constructor(private sagemaker: SageMaker) {}
generateCommand(
messages: ChatMessage[],
prompt: string,
options: CompletionOptions,
): InvokeEndpointWithResponseStreamCommand {
const payload = {
inputs: prompt,
parameters: this.sagemaker.completionOptions,
stream: true,
};

return new InvokeEndpointWithResponseStreamCommand({
EndpointName: options.model,
Body: new TextEncoder().encode(JSON.stringify(payload)),
ContentType: "application/json",
CustomAttributes: "accept_eula=false",
});
}
unwrapResponseChunk(rawValue: any): string {
const binaryChunk = rawValue.PayloadPart?.Bytes;
const textChunk = new TextDecoder().decode(binaryChunk);
try {
return JSON.parse(textChunk).token.text;
} catch (error) {
console.error(textChunk);
console.error(error);
return "";
}
}
}


export default SageMaker;
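The stream-unwrapping logic above can be exercised in isolation: each response-stream event carries `PayloadPart.Bytes`, a UTF-8 encoding of a JSON chunk in either an OpenAI-style `choices` shape or a TGI-style `token` shape. A minimal sketch; the sample events below are illustrative, not captured from a live endpoint:

```typescript
// Decode one SageMaker response-stream event into its text delta.
// Mirrors MessageAPIToolkit.unwrapResponseChunk: PayloadPart.Bytes is
// UTF-8 JSON in either an OpenAI-style "choices" or TGI-style "token" shape.
function unwrapChunk(rawValue: { PayloadPart?: { Bytes?: Uint8Array } }): string {
  const textChunk = new TextDecoder().decode(rawValue.PayloadPart?.Bytes);
  try {
    const chunk = JSON.parse(textChunk);
    if ("choices" in chunk) {
      return chunk.choices[0].delta.content; // OpenAI-compatible streaming delta
    } else if ("token" in chunk) {
      return chunk.token.text; // TGI / LMI token event
    }
    return "";
  } catch {
    return ""; // partial frame or non-JSON payload
  }
}

// Illustrative events:
const event = (o: unknown) => ({
  PayloadPart: { Bytes: new TextEncoder().encode(JSON.stringify(o)) },
});
const first = unwrapChunk(event({ choices: [{ delta: { content: "Hel" } }] }));
const second = unwrapChunk(event({ token: { text: "lo" } }));
```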
2 changes: 2 additions & 0 deletions core/llm/llms/index.ts
@@ -33,6 +33,7 @@ import TextGenWebUI from "./TextGenWebUI";
import Together from "./Together";
import ContinueProxy from "./stubs/ContinueProxy";
import WatsonX from "./WatsonX";
import SageMaker from "./SageMaker";
import { renderTemplatedString } from "../../promptFiles/renderTemplatedString";

const LLMs = [
@@ -52,6 +53,7 @@ const LLMs = [
LMStudio,
Mistral,
Bedrock,
SageMaker,
DeepInfra,
Flowise,
Groq,
10 changes: 10 additions & 0 deletions core/package-lock.json

2 changes: 2 additions & 0 deletions core/package.json
@@ -40,13 +40,15 @@
"dependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.620.1",
"@aws-sdk/credential-providers": "^3.620.1",
"@aws-sdk/client-sagemaker-runtime": "^3.621.0",
"@continuedev/config-types": "^1.0.10",
"@continuedev/llm-info": "^1.0.1",
"@mozilla/readability": "^0.5.0",
"@octokit/rest": "^20.0.2",
"@typescript-eslint/eslint-plugin": "^7.8.0",
"@typescript-eslint/parser": "^7.8.0",
"@xenova/transformers": "2.14.0",
"jinja-js": "0.1.8",
"adf-to-md": "^1.1.0",
"async-mutex": "^0.5.0",
"axios": "^1.6.7",
30 changes: 30 additions & 0 deletions docs/docs/reference/Model Providers/sagemaker.md
@@ -0,0 +1,30 @@
# AWS SageMaker

The SageMaker provider supports SageMaker endpoints deployed with [LMI](https://docs.djl.ai/docs/serving/serving/docs/lmi/index.html).

To set up SageMaker, add the following to your `config.json` file:

```json title="~/.continue/config.json"
{
"models": [
{
"title": "deepseek-6.7b-instruct",
"provider": "sagemaker",
"model": "lmi-model-deepseek-coder-xxxxxxx",
"region": "us-west-2"
}
]
}
```

The value of `model` should be the name of the SageMaker endpoint you deployed.

Authentication uses temporary or long-term credentials stored in `~/.aws/credentials` under a profile called "sagemaker":

```title="~/.aws/credentials"
[sagemaker]
aws_access_key_id = abcdefg
aws_secret_access_key = hijklmno
aws_session_token = pqrstuvwxyz # Optional: present only for short-term credentials.
```
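The provider also checks `completionOptions` for a `chat_template` key: when one is present, the template is rendered over the chat messages with jinja-js before the endpoint is invoked. A hypothetical sketch of such a config — the template string here is illustrative, not any real model's chat template:

```json title="~/.continue/config.json"
{
  "models": [
    {
      "title": "deepseek-6.7b-instruct",
      "provider": "sagemaker",
      "model": "lmi-model-deepseek-coder-xxxxxxx",
      "region": "us-west-2",
      "completionOptions": {
        "chat_template": "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}{% if add_generation_prompt %}assistant:{% endif %}"
      }
    }
  ]
}
```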
1 change: 1 addition & 0 deletions docs/docs/setup/model-providers.md
@@ -18,6 +18,7 @@ keywords:
Ollama,
HuggingFace,
AWS Bedrock,
AWS SageMaker,
]
---
