Skip to content

add watsonx reranker integration #4419

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 16, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions core/config/load.ts
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,11 @@ async function intermediateToFinalConfig(
model: "rerank-2",
...params,
};
return new rerankerClass(llmOptions);
return new rerankerClass(llmOptions, (url: string | URL, init: any) =>
fetchwithRequestOptions(url, init, {
...params?.requestOptions,
}),
);
}
return null;
}
Expand Down Expand Up @@ -991,5 +995,6 @@ export {
finalToBrowserConfig,
intermediateToFinalConfig,
loadContinueConfigFromJson,
type BrowserSerializedContinueConfig,
type BrowserSerializedContinueConfig
};

2 changes: 2 additions & 0 deletions core/context/allRerankers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import HuggingFaceTEI from "../llm/llms/HuggingFaceTEI";
import { LLMReranker } from "../llm/llms/llm";
import ContinueProxy from "../llm/llms/stubs/ContinueProxy";
import Voyage from "../llm/llms/Voyage";
import WatsonX from "../llm/llms/WatsonX";

export const AllRerankers: { [key: string]: any } = {
cohere: Cohere,
bedrock: Bedrock,
llm: LLMReranker,
voyage: Voyage,
watsonx: WatsonX,
"free-trial": FreeTrial,
"huggingface-tei": HuggingFaceTEI,
"continue-proxy": ContinueProxy,
Expand Down
4 changes: 3 additions & 1 deletion core/control-plane/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const modelDescriptionSchema = z.object({
"nebius",
"siliconflow",
"scaleway",
"watsonx"
]),
model: z.string(),
apiKey: z.string().optional(),
Expand Down Expand Up @@ -88,6 +89,7 @@ const embeddingsProviderSchema = z.object({
"nebius",
"siliconflow",
"scaleway",
"watsonx"
]),
apiBase: z.string().optional(),
apiKey: z.string().optional(),
Expand All @@ -109,7 +111,7 @@ const embeddingsProviderSchema = z.object({
});

const rerankerSchema = z.object({
name: z.enum(["cohere", "voyage", "llm"]),
name: z.enum(["cohere", "voyage", "llm", "watsonx"]),
params: z.record(z.any()).optional(),
});

Expand Down
117 changes: 85 additions & 32 deletions core/llm/llms/WatsonX.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js";
import { ChatMessage, Chunk, CompletionOptions, LLMOptions } from "../../index.js";
import { renderChatMessage } from "../../util/messageContent.js";
import { BaseLLM } from "../index.js";
import { streamResponse } from "../stream.js";
Expand Down Expand Up @@ -138,6 +138,26 @@ class WatsonX extends BaseLLM {
};
}

protected async updateWatsonxToken() {
var now = new Date().getTime() / 1000;
if (
watsonxToken === undefined ||
now > watsonxToken.expiration ||
watsonxToken.token === undefined
) {
watsonxToken = await this.getBearerToken();
} else {
console.log(
`Reusing token (expires in ${
(watsonxToken.expiration - now) / 60
} mins)`,
);
}
if (watsonxToken.token === undefined) {
throw new Error("Something went wrong. Check your credentials, please.");
}
}

protected async _complete(
prompt: string,
signal: AbortSignal,
Expand Down Expand Up @@ -174,23 +194,8 @@ class WatsonX extends BaseLLM {
signal: AbortSignal,
options: CompletionOptions,
): AsyncGenerator<ChatMessage> {
var now = new Date().getTime() / 1000;
if (
watsonxToken === undefined ||
now > watsonxToken.expiration ||
watsonxToken.token === undefined
) {
watsonxToken = await this.getBearerToken();
} else {
console.log(
`Reusing token (expires in ${
(watsonxToken.expiration - now) / 60
} mins)`,
);
}
if (watsonxToken.token === undefined) {
throw new Error("Something went wrong. Check your credentials, please.");
}
await this.updateWatsonxToken();

const stopSequences =
options.stop?.slice(0, 6) ??
(options.model?.includes("granite") ? ["Question:"] : []);
Expand Down Expand Up @@ -267,20 +272,8 @@ class WatsonX extends BaseLLM {
}

protected async _embed(chunks: string[]): Promise<number[][]> {
var now = new Date().getTime() / 1000;
if (
watsonxToken === undefined ||
now > watsonxToken.expiration ||
watsonxToken.token === undefined
) {
watsonxToken = await this.getBearerToken();
} else {
console.log(
`Reusing token (expires in ${
(watsonxToken.expiration - now) / 60
} mins)`,
);
}
await this.updateWatsonxToken();

const payload: any = {
inputs: chunks,
parameters: {
Expand Down Expand Up @@ -318,6 +311,66 @@ class WatsonX extends BaseLLM {
}
return embeddings.map((e: any) => e.embedding);
}


async rerank(query: string, chunks: Chunk[]): Promise<number[]> {
if (!query || !chunks.length) {
throw new Error("Query and chunks must not be empty");
}
try {
await this.updateWatsonxToken();

const headers = {
"Content-Type": "application/json",
Authorization: `${
watsonxToken.expiration === -1 ? "ZenApiKey" : "Bearer"
} ${watsonxToken.token}`,
};

const payload: any = {
inputs: chunks.map((chunk) => ({ text: chunk.content })),
query: query,
parameters: {
truncate_input_tokens: 500,
return_options: {
top_n: chunks.length
},
},
model_id: this.model,
project_id: this.projectId,
};

const resp = await this.fetch(
new URL(
`${this.apiBase}/ml/v1/text/rerank?version=${this.apiVersion}`,
),
{
method: "POST",
headers: headers,
body: JSON.stringify(payload),
},
);

if (!resp.ok) {
throw new Error(`Failed to rerank chunks: ${await resp.text()}`);
}
const data = await resp.json();
const ranking = data.results;

if (!ranking) {
throw new Error("Empty response received from Watsonx");
}

// Sort results by index to maintain original order
return ranking
.sort((a: any, b: any) => a.index - b.index)
.map((result: any) => result.score);

} catch (error) {
console.error("Error in WatsonxReranker.rerank:", error);
throw error;
}
}
}

export default WatsonX;
34 changes: 34 additions & 0 deletions docs/docs/customize/model-providers/more/watsonx.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,37 @@ To view the list of available embeddings models, visit [this page](https://datap
```
</TabItem>
</Tabs>


## Reranker

<Tabs groupId="config-example">
<TabItem value="yaml" label="YAML">
```yaml title="config.yaml"
models:
- name: Watsonx Reranker
provider: watsonx
model: cross-encoder/ms-marco-minilm-l-12-v2
apiBase: https://us-south.ml.cloud.ibm.com
projectId: PROJECT_ID
apiKey: API_KEY/ZENAPI_KEY/USERNAME:PASSWORD
apiVersion: 2024-03-14
```
</TabItem>
<TabItem value="json" label="JSON">
```json title="config.json"
{
"reranker": {
"name": "watsonx",
"params": {
"model": "cross-encoder/ms-marco-minilm-l-12-v2",
"apiBase": "watsonx endpoint e.g. https://us-south.ml.cloud.ibm.com",
"projectId": "PROJECT_ID",
"apiKey": "API_KEY/ZENAPI_KEY/USERNAME:PASSWORD",
"apiVersion": "2024-03-14"
}
}
}
```
</TabItem>
</Tabs>
1 change: 1 addition & 0 deletions extensions/vscode/config_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2843,6 +2843,7 @@
"bedrock",
"cohere",
"voyage",
"watsonx",
"llm",
"free-trial",
"huggingface-tei"
Expand Down
4 changes: 2 additions & 2 deletions extensions/vscode/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion packages/config-types/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ export const modelDescriptionSchema = z.object({
"continue-proxy",
"nebius",
"scaleway",
"watsonx"
]),
model: z.string(),
apiKey: z.string().optional(),
Expand Down Expand Up @@ -111,6 +112,7 @@ export const embeddingsProviderSchema = z.object({
"continue-proxy",
"nebius",
"scaleway",
"watsonx"
]),
apiBase: z.string().optional(),
apiKey: z.string().optional(),
Expand Down Expand Up @@ -173,7 +175,7 @@ export const contextProviderSchema = z.object({
export type ContextProvider = z.infer<typeof contextProviderSchema>;

export const rerankerSchema = z.object({
name: z.enum(["cohere", "voyage", "llm", "continue-proxy"]),
name: z.enum(["cohere", "voyage", "watsonx", "llm", "continue-proxy"]),
params: z.record(z.any()).optional(),
});
export type Reranker = z.infer<typeof rerankerSchema>;
Expand Down