Skip to content

Commit c3cbbd6

Browse files
committed
Merge branch 'dev' of https://github.com/continuedev/continue into dev
2 parents 1493192 + d901915 commit c3cbbd6

File tree

8 files changed

+322
-15
lines changed

8 files changed

+322
-15
lines changed

core/index.d.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,7 @@ export interface ModelDescription {
801801
}
802802

803803
export type EmbeddingsProviderName =
804+
| "sagemaker"
804805
| "bedrock"
805806
| "huggingface-tei"
806807
| "transformers.js"
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import {
2+
InvokeEndpointCommand,
3+
SageMakerRuntimeClient,
4+
} from "@aws-sdk/client-sagemaker-runtime";
5+
import { fromIni } from "@aws-sdk/credential-providers";
6+
7+
import {
8+
EmbeddingsProviderName,
9+
EmbedOptions,
10+
FetchFunction,
11+
} from "../../index.js";
12+
import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider.js";
13+
14+
15+
class SageMakerEmbeddingsProvider extends BaseEmbeddingsProvider {
16+
static providerName: EmbeddingsProviderName = "sagemaker";
17+
18+
static defaultOptions: Partial<EmbedOptions> | undefined = {
19+
region: "us-west-2",
20+
maxBatchSize: 1,
21+
};
22+
profile?: string | undefined;
23+
24+
constructor(options: EmbedOptions, fetch: FetchFunction) {
25+
super(options, fetch);
26+
if (!options.apiBase) {
27+
options.apiBase = `https://runtime.sagemaker.${options.region}.amazonaws.com`;
28+
}
29+
30+
if (options.profile) {
31+
this.profile = options.profile;
32+
} else {
33+
this.profile = "sagemaker";
34+
}
35+
}
36+
37+
async embed(chunks: string[]) {
38+
const credentials = await this._getCredentials();
39+
const client = new SageMakerRuntimeClient({
40+
region: this.options.region,
41+
credentials: {
42+
accessKeyId: credentials.accessKeyId,
43+
secretAccessKey: credentials.secretAccessKey,
44+
sessionToken: credentials.sessionToken || "",
45+
},
46+
});
47+
48+
const batchedChunks = this.getBatchedChunks(chunks);
49+
return (
50+
await Promise.all(
51+
batchedChunks.map(async (batch) => {
52+
const input = this._generateInvokeModelCommandInput(
53+
batch,
54+
this.options,
55+
);
56+
const command = new InvokeEndpointCommand(input);
57+
const response = await client.send(command);
58+
59+
if (response.Body) {
60+
const responseBody = JSON.parse(
61+
new TextDecoder().decode(response.Body),
62+
);
63+
// If the body contains a key called "embedding" or "embeddings", return the value, otherwise return the whole body
64+
if (responseBody.embedding) {
65+
return responseBody.embedding;
66+
} else if (responseBody.embeddings) {
67+
return responseBody.embeddings;
68+
} else {
69+
return responseBody;
70+
}
71+
}
72+
}),
73+
)
74+
).flat();
75+
}
76+
private _generateInvokeModelCommandInput(
77+
prompts: string | string[],
78+
options: EmbedOptions,
79+
): any {
80+
const payload = {
81+
inputs: prompts,
82+
normalize: true,
83+
// ...(options.requestOptions?.extraBodyProperties || {}),
84+
};
85+
86+
if (options.requestOptions?.extraBodyProperties) {
87+
Object.assign(payload, options.requestOptions.extraBodyProperties);
88+
}
89+
90+
return {
91+
EndpointName: this.options.model,
92+
Body: JSON.stringify(payload),
93+
ContentType: "application/json",
94+
CustomAttributes: "accept_eula=false",
95+
};
96+
}
97+
98+
private async _getCredentials() {
99+
try {
100+
return await fromIni({
101+
profile: this.profile,
102+
ignoreCache: true,
103+
})();
104+
} catch (e) {
105+
console.warn(
106+
`AWS profile with name ${this.profile} not found in ~/.aws/credentials, using default profile`,
107+
);
108+
return await fromIni()();
109+
}
110+
}
111+
}
112+
113+
export default SageMakerEmbeddingsProvider;
114+

core/indexing/embeddings/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { EmbeddingsProviderName } from "../../index.js";
22
import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider.js";
3+
import SageMakerEmbeddingsProvider from "./SageMakerEmbeddingsProvider.js";
34
import BedrockEmbeddingsProvider from "./BedrockEmbeddingsProvider.js";
45
import CohereEmbeddingsProvider from "./CohereEmbeddingsProvider.js";
56
import ContinueProxyEmbeddingsProvider from "./ContinueProxyEmbeddingsProvider.js";
@@ -22,6 +23,7 @@ export const allEmbeddingsProviders: Record<
2223
EmbeddingsProviderName,
2324
EmbeddingsProviderConstructor
2425
> = {
26+
sagemaker: SageMakerEmbeddingsProvider,
2527
bedrock: BedrockEmbeddingsProvider,
2628
ollama: OllamaEmbeddingsProvider,
2729
"transformers.js": TransformersJsEmbeddingsProvider,

docs/docs/customize/model-providers/more/sagemaker.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# AWS SageMaker
22

3-
SageMaker provider support SageMaker endpoint deployed with [LMI](https://docs.djl.ai/docs/serving/serving/docs/lmi/index.html)
3+
SageMaker can be used for both chat and embedding models. Chat models are supported for endpoints deployed with [LMI](https://docs.djl.ai/docs/serving/serving/docs/lmi/index.html), and embedding models are supported for endpoints deployed with [HuggingFace TEI](https://huggingface.co/blog/sagemaker-huggingface-embedding)
44

5-
To setup SageMaker, add the following to your `config.json` file:
5+
To setup SageMaker as a chat model provider, add the following to your `config.json` file:
66

77
```json title="config.json"
88
{
@@ -13,7 +13,11 @@ To setup SageMaker, add the following to your `config.json` file:
1313
"model": "lmi-model-deepseek-coder-xxxxxxx",
1414
"region": "us-west-2"
1515
}
16-
]
16+
],
17+
"embeddingsProvider": {
18+
"provider": "sagemaker",
19+
"model": "mxbai-embed-large-v1-endpoint"
20+
},
1721
}
1822
```
1923

docs/static/schemas/config.json

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,24 @@
246246
"region": {
247247
"title": "Region",
248248
"description": "The region where the model is hosted",
249-
"type": "string"
249+
"anyOf":[
250+
{
251+
"enum": [
252+
"us-east-1",
253+
"us-east-2",
254+
"us-west-1",
255+
"us-west-2",
256+
"eu-west-1",
257+
"eu-central-1",
258+
"ap-southeast-1",
259+
"ap-northeast-1",
260+
"ap-south-1"
261+
]
262+
},
263+
{
264+
"type": "string"
265+
}
266+
]
250267
},
251268
"profile": {
252269
"title": "Profile",
@@ -2150,7 +2167,8 @@
21502167
"gemini",
21512168
"voyage",
21522169
"nvidia",
2153-
"bedrock"
2170+
"bedrock",
2171+
"sagemaker"
21542172
]
21552173
},
21562174
"model": {
@@ -2184,7 +2202,24 @@
21842202
"region": {
21852203
"title": "Region",
21862204
"description": "The region where the model is hosted",
2187-
"type": "string"
2205+
"anyOf":[
2206+
{
2207+
"enum": [
2208+
"us-east-1",
2209+
"us-east-2",
2210+
"us-west-1",
2211+
"us-west-2",
2212+
"eu-west-1",
2213+
"eu-central-1",
2214+
"ap-southeast-1",
2215+
"ap-northeast-1",
2216+
"ap-south-1"
2217+
]
2218+
},
2219+
{
2220+
"type": "string"
2221+
}
2222+
]
21882223
},
21892224
"profile": {
21902225
"title": "Profile",
@@ -2206,6 +2241,24 @@
22062241
"then": {
22072242
"required": ["apiKey"]
22082243
}
2244+
},
2245+
{
2246+
"if": {
2247+
"properties": {
2248+
"provider": {
2249+
"enum": ["sagemaker"]
2250+
}
2251+
},
2252+
"required": ["provider"]
2253+
},
2254+
"then": {
2255+
"properties": {
2256+
"model": {
2257+
"markdownDescription": "SageMaker endpoint name"
2258+
}
2259+
},
2260+
"required": ["model"]
2261+
}
22092262
}
22102263
]
22112264
},

extensions/intellij/src/main/resources/config_schema.json

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,25 @@
246246
"region": {
247247
"title": "Region",
248248
"description": "The region where the model is hosted",
249-
"type": "string"
249+
"anyOf":[
250+
{
251+
"enum": [
252+
"us-east-1",
253+
"us-east-2",
254+
"us-west-1",
255+
"us-west-2",
256+
"eu-west-1",
257+
"eu-central-1",
258+
"ap-southeast-1",
259+
"ap-northeast-1",
260+
"ap-south-1"
261+
],
262+
"type" : "string"
263+
},
264+
{
265+
"type": "string"
266+
}
267+
]
250268
},
251269
"profile": {
252270
"title": "Profile",
@@ -2150,7 +2168,8 @@
21502168
"gemini",
21512169
"voyage",
21522170
"nvidia",
2153-
"bedrock"
2171+
"bedrock",
2172+
"sagemaker"
21542173
]
21552174
},
21562175
"model": {
@@ -2184,7 +2203,7 @@
21842203
"region": {
21852204
"title": "Region",
21862205
"description": "The region where the model is hosted",
2187-
"type": "string"
2206+
"$ref": "#/definitions/ModelDescription/properties/region"
21882207
},
21892208
"profile": {
21902209
"title": "Profile",
@@ -2206,6 +2225,24 @@
22062225
"then": {
22072226
"required": ["apiKey"]
22082227
}
2228+
},
2229+
{
2230+
"if": {
2231+
"properties": {
2232+
"provider": {
2233+
"enum": ["sagemaker"]
2234+
}
2235+
},
2236+
"required": ["provider"]
2237+
},
2238+
"then": {
2239+
"properties": {
2240+
"model": {
2241+
"markdownDescription": "SageMaker endpoint name"
2242+
}
2243+
},
2244+
"required": ["model"]
2245+
}
22092246
}
22102247
]
22112248
},

extensions/vscode/config_schema.json

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,25 @@
246246
"region": {
247247
"title": "Region",
248248
"description": "The region where the model is hosted",
249-
"type": "string"
249+
"anyOf":[
250+
{
251+
"enum": [
252+
"us-east-1",
253+
"us-east-2",
254+
"us-west-1",
255+
"us-west-2",
256+
"eu-west-1",
257+
"eu-central-1",
258+
"ap-southeast-1",
259+
"ap-northeast-1",
260+
"ap-south-1"
261+
],
262+
"type" : "string"
263+
},
264+
{
265+
"type": "string"
266+
}
267+
]
250268
},
251269
"profile": {
252270
"title": "Profile",
@@ -2150,7 +2168,8 @@
21502168
"gemini",
21512169
"voyage",
21522170
"nvidia",
2153-
"bedrock"
2171+
"bedrock",
2172+
"sagemaker"
21542173
]
21552174
},
21562175
"model": {
@@ -2184,7 +2203,7 @@
21842203
"region": {
21852204
"title": "Region",
21862205
"description": "The region where the model is hosted",
2187-
"type": "string"
2206+
"$ref": "#/definitions/ModelDescription/properties/region"
21882207
},
21892208
"profile": {
21902209
"title": "Profile",
@@ -2206,6 +2225,24 @@
22062225
"then": {
22072226
"required": ["apiKey"]
22082227
}
2228+
},
2229+
{
2230+
"if": {
2231+
"properties": {
2232+
"provider": {
2233+
"enum": ["sagemaker"]
2234+
}
2235+
},
2236+
"required": ["provider"]
2237+
},
2238+
"then": {
2239+
"properties": {
2240+
"model": {
2241+
"markdownDescription": "SageMaker endpoint name"
2242+
}
2243+
},
2244+
"required": ["model"]
2245+
}
22092246
}
22102247
]
22112248
},

0 commit comments

Comments
 (0)