
Commit ac9e9cb

sestinj authored and Patrick-Erichsen committed

OpenAI Adapters (continuedev#1859)

* embeddings support in openai-adapters
* rerank support in openai-adapters
* cohere embed/rerank
* continue-proxy reranker and embeddings providers
* test for openai-adapters
* embeddings provider and reranker for continue proxy

1 parent b7dc2bb · commit ac9e9cb

Showing 33 changed files with 4,872 additions and 307 deletions.

.vscode/launch.json (+23 -1)

@@ -68,7 +68,10 @@
       "${workspaceFolder}/extensions/vscode/out/extension.js",
       "/Users/natesesti/.continue/config.ts"
     ],
-    "preLaunchTask": "vscode-extension:build"
+    "preLaunchTask": "vscode-extension:build",
+    "env": {
+      "CONTROL_PLANE_ENV": "local"
+    }
   },
   // Has to be run after starting the server (separately or using the compound configuration)
   {
@@ -112,5 +115,24 @@
     "console": "integratedTerminal",
     "internalConsoleOptions": "neverOpen"
   }
+  // {
+  //   "name": "[openai-adapters] Jest Test Debugger, Current Open File",
+  //   "type": "node",
+  //   "request": "launch",
+  //   "runtimeArgs": [
+  //     "--inspect-brk",
+  //     "${workspaceRoot}/packages/openai-adapters/node_modules/jest/bin/jest.js",
+  //     "--runInBand",
+  //     "--config",
+  //     "${workspaceRoot}/packages/openai-adapters/jest.config.mjs",
+  //     "${relativeFile}"
+  //   ],
+  //   "cwd": "${workspaceRoot}/packages/openai-adapters",
+  //   "console": "integratedTerminal",
+  //   "internalConsoleOptions": "neverOpen",
+  //   "env": {
+  //     "NODE_OPTIONS": "--experimental-vm-modules"
+  //   }
+  // }
 ]
}

core/config/profile/doLoadConfig.ts (+13)

@@ -4,8 +4,10 @@ import {
   IdeSettings,
   SerializedContinueConfig,
 } from "../..";
+import { ContinueProxyReranker } from "../../context/rerankers/ContinueProxyReranker";
 import { ControlPlaneClient } from "../../control-plane/client";
 import { TeamAnalytics } from "../../control-plane/TeamAnalytics";
+import ContinueProxyEmbeddingsProvider from "../../indexing/embeddings/ContinueProxyEmbeddingsProvider";
 import ContinueProxy from "../../llm/llms/stubs/ContinueProxy";
 import { Telemetry } from "../../util/posthog";
 import { loadFullConfigNode } from "../load";
@@ -65,5 +67,16 @@ export default async function doLoadConfig(
     },
   );
 
+  if (newConfig.embeddingsProvider?.providerName === "continue-proxy") {
+    (
+      newConfig.embeddingsProvider as ContinueProxyEmbeddingsProvider
+    ).workOsAccessToken = workOsAccessToken;
+  }
+
+  if (newConfig.reranker?.name === "continue-proxy") {
+    (newConfig.reranker as ContinueProxyReranker).workOsAccessToken =
+      workOsAccessToken;
+  }
+
   return newConfig;
 }
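
The two new branches above run after loadFullConfigNode has already constructed the provider objects: if the user's config selected the "continue-proxy" embeddings provider or reranker, the WorkOS access token obtained from the control plane is pushed into them so that later requests authenticate against the model proxy. A minimal sketch of a user config that would take this code path follows; the embeddingsProvider "provider" field and the reranker "params" field are assumed from Continue's serialized config schema rather than shown in this diff, and the values are illustrative only.

// Sketch only: local stand-in types, not Continue's real SerializedContinueConfig.
type SketchConfig = {
  embeddingsProvider?: { provider: string; apiKey?: string };
  reranker?: { name: string; params?: Record<string, unknown> };
};

const userConfig: SketchConfig = {
  // Resolves to ContinueProxyEmbeddingsProvider; apiKey is deliberately
  // omitted because doLoadConfig assigns workOsAccessToken afterwards.
  embeddingsProvider: { provider: "continue-proxy" },
  // Resolves to ContinueProxyReranker; its apiKey is likewise filled in
  // from the WorkOS token by the branch above.
  reranker: { name: "continue-proxy", params: {} },
};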

core/context/rerankers/ContinueProxyReranker.ts (new file, +46)

@@ -0,0 +1,46 @@
+import fetch from "node-fetch";
+import { CONTROL_PLANE_URL } from "../../control-plane/client.js";
+import { Chunk, Reranker } from "../../index.js";
+
+export class ContinueProxyReranker implements Reranker {
+  name = "continue-proxy";
+
+  private _workOsAccessToken: string | undefined = undefined;
+
+  get workOsAccessToken(): string | undefined {
+    return this._workOsAccessToken;
+  }
+
+  set workOsAccessToken(value: string | undefined) {
+    if (this._workOsAccessToken !== value) {
+      this._workOsAccessToken = value;
+      this.params.apiKey = value!;
+    }
+  }
+
+  constructor(
+    private readonly params: {
+      apiKey: string;
+      model?: string;
+    },
+  ) {}
+
+  async rerank(query: string, chunks: Chunk[]): Promise<number[]> {
+    const url = new URL("/model-proxy/v1/rerank", CONTROL_PLANE_URL);
+    const resp = await fetch(url, {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+        Authorization: `Bearer ${this.params.apiKey}`,
+      },
+      body: JSON.stringify({
+        query,
+        documents: chunks.map((chunk) => chunk.content),
+        model: this.params.model,
+      }),
+    });
+    const data: any = await resp.json();
+    const results = data.data.sort((a: any, b: any) => a.index - b.index);
+    return results.map((result: any) => result.relevance_score);
+  }
+}
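
The new reranker POSTs the query and the chunk contents to the control plane's /model-proxy/v1/rerank endpoint, sorts the Cohere-style results back into input order by their index, and returns just the relevance scores. A hedged usage sketch follows; it assumes it lives inside the core package so the relative imports resolve, the model name and environment variable are placeholders, and only the content field of each Chunk matters to rerank(), so the other fields are elided via a type assertion.

// Usage sketch, not part of the commit.
import { ContinueProxyReranker } from "./context/rerankers/ContinueProxyReranker.js";
import { Chunk } from "./index.js";

async function demo() {
  const reranker = new ContinueProxyReranker({
    apiKey: "", // overwritten below through the workOsAccessToken setter
    model: "rerank-english-v3.0", // assumed model name, forwarded to the proxy
  });
  reranker.workOsAccessToken = process.env.WORKOS_ACCESS_TOKEN; // placeholder env var

  const chunks = [
    { content: "function add(a, b) { return a + b; }" },
    { content: "README: installation instructions" },
  ] as Chunk[]; // only `content` is read by rerank()

  // One score per input chunk, in the original chunk order.
  const scores = await reranker.rerank("how do I add two numbers?", chunks);
  console.log(scores);
}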

core/context/rerankers/index.ts (+4 -2)

@@ -1,14 +1,16 @@
 import { RerankerName } from "../../index.js";
 import { CohereReranker } from "./cohere.js";
+import { ContinueProxyReranker } from "./ContinueProxyReranker.js";
 import { FreeTrialReranker } from "./freeTrial.js";
 import { LLMReranker } from "./llm.js";
+import { HuggingFaceTEIReranker } from "./tei.js";
 import { VoyageReranker } from "./voyage.js";
-import {HuggingFaceTEIReranker} from "./tei.js"
 
 export const AllRerankers: { [key in RerankerName]: any } = {
   cohere: CohereReranker,
   llm: LLMReranker,
   voyage: VoyageReranker,
   "free-trial": FreeTrialReranker,
-  "huggingface-tei": HuggingFaceTEIReranker
+  "huggingface-tei": HuggingFaceTEIReranker,
+  "continue-proxy": ContinueProxyReranker,
 };
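
AllRerankers is the lookup table that maps every RerankerName to its implementing class, and the "continue-proxy" entry is what lets the new reranker be selected by name. The sketch below shows how such a registry can be consumed; the actual construction site lives in config-loading code outside this diff, the params shape differs per reranker, and the import paths assume the sketch sits next to this index.ts inside core.

// Hedged sketch of consuming the registry above.
import { Reranker, RerankerName } from "../../index.js";
import { AllRerankers } from "./index.js";

function rerankerFromName(
  name: RerankerName,
  params: Record<string, unknown>,
): Reranker | undefined {
  const cls = AllRerankers[name];
  // Each entry is a class whose constructor takes its own params object.
  return cls ? new cls(params) : undefined;
}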

core/control-plane/client.ts (+3 -2)

@@ -18,9 +18,10 @@ export interface ControlPlaneWorkspace {
 
 export interface ControlPlaneModelDescription extends ModelDescription {}
 
-// export const CONTROL_PLANE_URL = "http://localhost:3001";
 export const CONTROL_PLANE_URL =
-  "https://control-plane-api-service-i3dqylpbqa-uc.a.run.app";
+  process.env.CONTROL_PLANE_ENV === "local"
+    ? "http://localhost:3001"
+    : "https://control-plane-api-service-i3dqylpbqa-uc.a.run.app";
 
 export class ControlPlaneClient {
   private static URL = CONTROL_PLANE_URL;
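
Instead of a commented-out localhost constant, CONTROL_PLANE_URL now switches on the CONTROL_PLANE_ENV environment variable, which the launch.json change earlier in this commit sets to "local" for the VS Code extension debug configuration. Because the new providers build endpoints with new URL(path, CONTROL_PLANE_URL) and the path starts with "/", the base URL's own path (if any) is replaced rather than appended to; the snippet below just demonstrates that resolution for both possible base values.

// WHATWG URL resolution: a leading "/" path replaces the base's path.
const bases = [
  "http://localhost:3001",
  "https://control-plane-api-service-i3dqylpbqa-uc.a.run.app",
];
for (const base of bases) {
  console.log(new URL("/model-proxy/v1/rerank", base).toString());
}
// -> http://localhost:3001/model-proxy/v1/rerank
// -> https://control-plane-api-service-i3dqylpbqa-uc.a.run.app/model-proxy/v1/rerank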

core/index.d.ts (+6 -2)

@@ -739,7 +739,9 @@ export type EmbeddingsProviderName =
   | "openai"
   | "cohere"
   | "free-trial"
-  | "gemini";
+  | "gemini"
+  | "continue-proxy"
+  | "deepinfra";
 
 export interface EmbedOptions {
   apiBase?: string;
@@ -758,6 +760,7 @@ export interface EmbeddingsProviderDescription extends EmbedOptions {
 
 export interface EmbeddingsProvider {
   id: string;
+  providerName: EmbeddingsProviderName;
   maxChunkSize: number;
   embed(chunks: string[]): Promise<number[][]>;
 }
@@ -767,7 +770,8 @@ export type RerankerName =
   | "voyage"
   | "llm"
   | "free-trial"
-  | "huggingface-tei";
+  | "huggingface-tei"
+  | "continue-proxy";
 
 export interface RerankerDescription {
   name: RerankerName;
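
The EmbeddingsProvider interface now requires a providerName drawn from the EmbeddingsProviderName union, which is what doLoadConfig checks to recognize the continue-proxy provider at runtime. A minimal object satisfying the updated interface looks like the sketch below (placement inside the core package is assumed so the type import resolves; real providers extend BaseEmbeddingsProvider instead, and the vectors here are obviously fake).

// Minimal sketch of the updated interface shape.
import { EmbeddingsProvider } from "./index.js";

const fakeProvider: EmbeddingsProvider = {
  id: "FakeEmbeddingsProvider::fake-model",
  providerName: "openai", // must be a member of EmbeddingsProviderName
  maxChunkSize: 512,
  async embed(chunks: string[]): Promise<number[][]> {
    // One constant-length placeholder vector per chunk.
    return chunks.map(() => [0, 0, 0]);
  },
};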

core/indexing/embeddings/BaseEmbeddingsProvider.ts (+8)

@@ -1,6 +1,7 @@
 import {
   EmbedOptions,
   EmbeddingsProvider,
+  EmbeddingsProviderName,
   FetchFunction,
 } from "../../index.js";
 
@@ -17,6 +18,11 @@ abstract class BaseEmbeddingsProvider implements IBaseEmbeddingsProvider {
   static maxBatchSize: IBaseEmbeddingsProvider["maxBatchSize"];
   static defaultOptions: IBaseEmbeddingsProvider["defaultOptions"];
 
+  static providerName: EmbeddingsProviderName;
+  get providerName(): EmbeddingsProviderName {
+    return (this.constructor as typeof BaseEmbeddingsProvider).providerName;
+  }
+
   options: IBaseEmbeddingsProvider["options"];
   fetch: IBaseEmbeddingsProvider["fetch"];
   id: IBaseEmbeddingsProvider["id"];
@@ -38,6 +44,8 @@ abstract class BaseEmbeddingsProvider implements IBaseEmbeddingsProvider {
       this.id = `${this.constructor.name}::${this.options.model}`;
     }
   }
+  defaultOptions?: EmbedOptions | undefined;
+  maxBatchSize?: number | undefined;
 
   abstract embed(chunks: string[]): Promise<number[][]>;
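
The base class implements the new providerName requirement once for every subclass: each concrete provider declares a static providerName, and the instance getter reads it back through this.constructor, so instances report their provider without repeating an instance field. A small self-contained illustration of that pattern (generic names, not Continue APIs):

abstract class Base {
  // Each subclass sets this static; the getter below resolves it dynamically.
  static kind: string;
  get kind(): string {
    return (this.constructor as typeof Base).kind;
  }
}

class Foo extends Base {
  static kind = "foo";
}

console.log(new Foo().kind); // "foo", looked up from the subclass's static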

core/indexing/embeddings/CohereEmbeddingsProvider.ts (+3 -1)

@@ -1,11 +1,13 @@
 import { Response } from "node-fetch";
-import { EmbedOptions } from "../../index.js";
+import { EmbeddingsProviderName, EmbedOptions } from "../../index.js";
 import { withExponentialBackoff } from "../../util/withExponentialBackoff.js";
 import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider.js";
 
 class CohereEmbeddingsProvider extends BaseEmbeddingsProvider {
   static maxBatchSize = 96;
 
+  static providerName: EmbeddingsProviderName = "cohere";
+
   static defaultOptions: Partial<EmbedOptions> | undefined = {
     apiBase: "https://api.cohere.ai/v1/",
     model: "embed-english-v3.0",

core/indexing/embeddings/ContinueProxyEmbeddingsProvider.ts (new file, +25)

@@ -0,0 +1,25 @@
+import { EmbeddingsProviderName, EmbedOptions } from "../..";
+import { CONTROL_PLANE_URL } from "../../control-plane/client";
+import OpenAIEmbeddingsProvider from "./OpenAIEmbeddingsProvider";
+
+class ContinueProxyEmbeddingsProvider extends OpenAIEmbeddingsProvider {
+  static providerName: EmbeddingsProviderName = "continue-proxy";
+  static defaultOptions: Partial<EmbedOptions> | undefined = {
+    apiBase: new URL("/model-proxy/v1", CONTROL_PLANE_URL).toString(),
+  };
+
+  private _workOsAccessToken: string | undefined = undefined;
+
+  get workOsAccessToken(): string | undefined {
+    return this._workOsAccessToken;
+  }
+
+  set workOsAccessToken(value: string | undefined) {
+    if (this._workOsAccessToken !== value) {
+      this._workOsAccessToken = value;
+      this.options.apiKey = value;
+    }
+  }
+}
+
+export default ContinueProxyEmbeddingsProvider;
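
This provider reuses OpenAIEmbeddingsProvider's request logic wholesale: the only overrides are a default apiBase pointing at the control plane's /model-proxy/v1 route and a workOsAccessToken setter that copies the token into options.apiKey, and only when the value actually changes. A hedged sketch of the flow is below; construction is elided with declare because the base constructor is not part of this diff, and the environment variable name is a placeholder.

// Sketch only: how the provider is used once doLoadConfig has built it from
// a "continue-proxy" config entry.
import ContinueProxyEmbeddingsProvider from "./ContinueProxyEmbeddingsProvider.js";

declare const provider: ContinueProxyEmbeddingsProvider;

async function demo() {
  // doLoadConfig performs this assignment after fetching the WorkOS token;
  // the setter routes it into options.apiKey.
  provider.workOsAccessToken = process.env.WORKOS_ACCESS_TOKEN;

  // Requests then hit `${CONTROL_PLANE_URL}/model-proxy/v1/...` through the
  // inherited OpenAI-compatible embed() implementation.
  const vectors = await provider.embed(["hello world"]);
  console.log(vectors.length);
}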

core/indexing/embeddings/DeepInfraEmbeddingsProvider.ts (+2 -1)

@@ -1,8 +1,9 @@
-import { EmbedOptions } from "../../index.js";
+import { EmbeddingsProviderName, EmbedOptions } from "../../index.js";
 import { withExponentialBackoff } from "../../util/withExponentialBackoff.js";
 import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider.js";
 
 class DeepInfraEmbeddingsProvider extends BaseEmbeddingsProvider {
+  static providerName: EmbeddingsProviderName = "deepinfra";
   static defaultOptions: Partial<EmbedOptions> | undefined = {
     model: "sentence-transformers/all-MiniLM-L6-v2",
   };

core/indexing/embeddings/FreeTrialEmbeddingsProvider.ts (+6 -1)

@@ -1,11 +1,16 @@
 import { Response } from "node-fetch";
 import { getHeaders } from "../../continueServer/stubs/headers.js";
 import { constants } from "../../deploy/constants.js";
-import { EmbedOptions, FetchFunction } from "../../index.js";
+import {
+  EmbeddingsProviderName,
+  EmbedOptions,
+  FetchFunction,
+} from "../../index.js";
 import { withExponentialBackoff } from "../../util/withExponentialBackoff.js";
 import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider.js";
 
 class FreeTrialEmbeddingsProvider extends BaseEmbeddingsProvider {
+  static providerName: EmbeddingsProviderName = "free-trial";
   static maxBatchSize = 128;
 
   static defaultOptions: Partial<EmbedOptions> | undefined = {

core/indexing/embeddings/GeminiEmbeddingsProvider.ts (+5 -5)

@@ -1,17 +1,17 @@
-import { Response } from "node-fetch";
-import { withExponentialBackoff } from "../../util/withExponentialBackoff.js";
-import BaseEmbeddingsProvider, {
-  IBaseEmbeddingsProvider,
-} from "./BaseEmbeddingsProvider.js";
 import {
   EmbedContentRequest,
   EmbedContentResponse,
 } from "@google/generative-ai";
+import { Response } from "node-fetch";
+import { EmbeddingsProviderName } from "../../index.js";
+import { withExponentialBackoff } from "../../util/withExponentialBackoff.js";
+import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider.js";
 
 /**
  * [View the Gemini Text Embedding docs.](https://ai.google.dev/gemini-api/docs/models/gemini#text-embedding-and-embedding)
  */
 class GeminiEmbeddingsProvider extends BaseEmbeddingsProvider {
+  static providerName: EmbeddingsProviderName = "gemini";
   static maxBatchSize = 2048;
 
   static defaultOptions = {

core/indexing/embeddings/HuggingFaceTEIEmbeddingsProvider.ts (+14 -11)

@@ -1,10 +1,11 @@
-import fetch, { Response } from "node-fetch";
-import { EmbedOptions, FetchFunction } from "../..";
+import { Response } from "node-fetch";
+import { EmbeddingsProviderName, EmbedOptions, FetchFunction } from "../..";
 import { withExponentialBackoff } from "../../util/withExponentialBackoff";
 import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider";
 
 class HuggingFaceTEIEmbeddingsProvider extends BaseEmbeddingsProvider {
-  private maxBatchSize = 32;
+  static providerName: EmbeddingsProviderName = "huggingface-tei";
+  maxBatchSize = 32;
 
   static defaultOptions: Partial<EmbedOptions> | undefined = {
     apiBase: "http://localhost:8080",
@@ -17,7 +18,7 @@ class HuggingFaceTEIEmbeddingsProvider extends BaseEmbeddingsProvider {
     if (!this.options.apiBase?.endsWith("/")) {
       this.options.apiBase += "/";
     }
-    this.doInfoRequest().then(response => {
+    this.doInfoRequest().then((response) => {
       this.options.model = response.model_id;
       this.maxBatchSize = response.max_client_batch_size;
     });
@@ -26,7 +27,9 @@ class HuggingFaceTEIEmbeddingsProvider extends BaseEmbeddingsProvider {
   async embed(chunks: string[]) {
     const promises = [];
     for (let i = 0; i < chunks.length; i += this.maxBatchSize) {
-      promises.push(this.doEmbedRequest(chunks.slice(i, i + this.maxBatchSize)));
+      promises.push(
+        this.doEmbedRequest(chunks.slice(i, i + this.maxBatchSize)),
+      );
     }
     const results = await Promise.all(promises);
     return results.flat();
@@ -37,11 +40,11 @@ class HuggingFaceTEIEmbeddingsProvider extends BaseEmbeddingsProvider {
       this.fetch(new URL("embed", this.options.apiBase), {
         method: "POST",
         body: JSON.stringify({
-          inputs: batch
+          inputs: batch,
         }),
         headers: {
           "Content-Type": "application/json",
-        }
+        },
       }),
     );
     if (!resp.ok) {
@@ -75,9 +78,9 @@ class TEIEmbedError extends Error {
 }
 
 type TEIEmbedErrorResponse = {
-  error: string
-  error_type: string
-}
+  error: string;
+  error_type: string;
+};
 
 type TEIInfoResponse = {
   model_id: string;
@@ -86,7 +89,7 @@ type TEIInfoResponse = {
   model_type: {
     embedding: {
       pooling: string;
-    }
+    };
   };
   max_concurrent_requests: number;
   max_input_length: number;
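
Most of this file's diff is formatter noise, but two changes matter: maxBatchSize is now a public instance field (it has to be mutable, since the constructor overwrites it with max_client_batch_size from the TEI /info response), and the class gains its static providerName. The embed() method splits the input into maxBatchSize-sized batches and issues the requests in parallel before flattening the results back into input order; a generic sketch of that pattern, with processBatch standing in for doEmbedRequest, is shown below.

// Generic batch-and-parallelize sketch of the embed() loop above.
async function embedInBatches(
  chunks: string[],
  processBatch: (batch: string[]) => Promise<number[][]>,
  batchSize = 32,
): Promise<number[][]> {
  const promises: Promise<number[][]>[] = [];
  for (let i = 0; i < chunks.length; i += batchSize) {
    promises.push(processBatch(chunks.slice(i, i + batchSize)));
  }
  // Each batch yields one embedding per chunk; flatten preserves input order.
  const results = await Promise.all(promises);
  return results.flat();
}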

core/indexing/embeddings/OllamaEmbeddingsProvider.ts (+6 -1)

@@ -1,4 +1,8 @@
-import { EmbedOptions, FetchFunction } from "../../index.js";
+import {
+  EmbeddingsProviderName,
+  EmbedOptions,
+  FetchFunction,
+} from "../../index.js";
 import { withExponentialBackoff } from "../../util/withExponentialBackoff.js";
 import BaseEmbeddingsProvider, {
   IBaseEmbeddingsProvider,
@@ -41,6 +45,7 @@ async function embedOne(
 }
 
 class OllamaEmbeddingsProvider extends BaseEmbeddingsProvider {
+  static providerName: EmbeddingsProviderName = "ollama";
   static defaultOptions: IBaseEmbeddingsProvider["defaultOptions"] = {
     apiBase: "http://localhost:11434/",
     model: "nomic-embed-text",

core/indexing/embeddings/OpenAIEmbeddingsProvider.ts (+2 -1)

@@ -1,9 +1,10 @@
 import { Response } from "node-fetch";
-import { EmbedOptions } from "../../index.js";
+import { EmbeddingsProviderName, EmbedOptions } from "../../index.js";
 import { withExponentialBackoff } from "../../util/withExponentialBackoff.js";
 import BaseEmbeddingsProvider from "./BaseEmbeddingsProvider.js";
 
 class OpenAIEmbeddingsProvider extends BaseEmbeddingsProvider {
+  static providerName: EmbeddingsProviderName = "openai";
   // https://platform.openai.com/docs/api-reference/embeddings/create is 2048
   // but Voyage is 128
   static maxBatchSize = 128;
