Skip to content

Commit ac9c611

Browse files
authored
Merge pull request #4419 from mq200/feat/watsonx_reranker_integration
add watsonx reranker integration
2 parents 7bb5257 + e26d189 commit ac9c611

File tree

7 files changed

+135
-36
lines changed

7 files changed

+135
-36
lines changed

core/config/load.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,11 @@ async function intermediateToFinalConfig(
505505
model: "rerank-2",
506506
...params,
507507
};
508-
return new rerankerClass(llmOptions);
508+
return new rerankerClass(llmOptions, (url: string | URL, init: any) =>
509+
fetchwithRequestOptions(url, init, {
510+
...params?.requestOptions,
511+
}),
512+
);
509513
}
510514
return null;
511515
}
@@ -994,5 +998,6 @@ export {
994998
finalToBrowserConfig,
995999
intermediateToFinalConfig,
9961000
loadContinueConfigFromJson,
997-
type BrowserSerializedContinueConfig,
1001+
type BrowserSerializedContinueConfig
9981002
};
1003+

core/context/allRerankers.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ import HuggingFaceTEI from "../llm/llms/HuggingFaceTEI";
55
import { LLMReranker } from "../llm/llms/llm";
66
import ContinueProxy from "../llm/llms/stubs/ContinueProxy";
77
import Voyage from "../llm/llms/Voyage";
8+
import WatsonX from "../llm/llms/WatsonX";
89

910
export const AllRerankers: { [key: string]: any } = {
1011
cohere: Cohere,
1112
bedrock: Bedrock,
1213
llm: LLMReranker,
1314
voyage: Voyage,
15+
watsonx: WatsonX,
1416
"free-trial": FreeTrial,
1517
"huggingface-tei": HuggingFaceTEI,
1618
"continue-proxy": ContinueProxy,

core/control-plane/schema.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ const modelDescriptionSchema = z.object({
1919
"nebius",
2020
"siliconflow",
2121
"scaleway",
22+
"watsonx"
2223
]),
2324
model: z.string(),
2425
apiKey: z.string().optional(),
@@ -88,6 +89,7 @@ const embeddingsProviderSchema = z.object({
8889
"nebius",
8990
"siliconflow",
9091
"scaleway",
92+
"watsonx"
9193
]),
9294
apiBase: z.string().optional(),
9395
apiKey: z.string().optional(),
@@ -109,7 +111,7 @@ const embeddingsProviderSchema = z.object({
109111
});
110112

111113
const rerankerSchema = z.object({
112-
name: z.enum(["cohere", "voyage", "llm"]),
114+
name: z.enum(["cohere", "voyage", "llm", "watsonx"]),
113115
params: z.record(z.any()).optional(),
114116
});
115117

core/llm/llms/WatsonX.ts

Lines changed: 85 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js";
1+
import { ChatMessage, Chunk, CompletionOptions, LLMOptions } from "../../index.js";
22
import { renderChatMessage } from "../../util/messageContent.js";
33
import { BaseLLM } from "../index.js";
44
import { streamResponse } from "../stream.js";
@@ -138,6 +138,26 @@ class WatsonX extends BaseLLM {
138138
};
139139
}
140140

141+
protected async updateWatsonxToken() {
142+
var now = new Date().getTime() / 1000;
143+
if (
144+
watsonxToken === undefined ||
145+
now > watsonxToken.expiration ||
146+
watsonxToken.token === undefined
147+
) {
148+
watsonxToken = await this.getBearerToken();
149+
} else {
150+
console.log(
151+
`Reusing token (expires in ${
152+
(watsonxToken.expiration - now) / 60
153+
} mins)`,
154+
);
155+
}
156+
if (watsonxToken.token === undefined) {
157+
throw new Error("Something went wrong. Check your credentials, please.");
158+
}
159+
}
160+
141161
protected async _complete(
142162
prompt: string,
143163
signal: AbortSignal,
@@ -174,23 +194,8 @@ class WatsonX extends BaseLLM {
174194
signal: AbortSignal,
175195
options: CompletionOptions,
176196
): AsyncGenerator<ChatMessage> {
177-
var now = new Date().getTime() / 1000;
178-
if (
179-
watsonxToken === undefined ||
180-
now > watsonxToken.expiration ||
181-
watsonxToken.token === undefined
182-
) {
183-
watsonxToken = await this.getBearerToken();
184-
} else {
185-
console.log(
186-
`Reusing token (expires in ${
187-
(watsonxToken.expiration - now) / 60
188-
} mins)`,
189-
);
190-
}
191-
if (watsonxToken.token === undefined) {
192-
throw new Error("Something went wrong. Check your credentials, please.");
193-
}
197+
await this.updateWatsonxToken();
198+
194199
const stopSequences =
195200
options.stop?.slice(0, 6) ??
196201
(options.model?.includes("granite") ? ["Question:"] : []);
@@ -267,20 +272,8 @@ class WatsonX extends BaseLLM {
267272
}
268273

269274
protected async _embed(chunks: string[]): Promise<number[][]> {
270-
var now = new Date().getTime() / 1000;
271-
if (
272-
watsonxToken === undefined ||
273-
now > watsonxToken.expiration ||
274-
watsonxToken.token === undefined
275-
) {
276-
watsonxToken = await this.getBearerToken();
277-
} else {
278-
console.log(
279-
`Reusing token (expires in ${
280-
(watsonxToken.expiration - now) / 60
281-
} mins)`,
282-
);
283-
}
275+
await this.updateWatsonxToken();
276+
284277
const payload: any = {
285278
inputs: chunks,
286279
parameters: {
@@ -318,6 +311,66 @@ class WatsonX extends BaseLLM {
318311
}
319312
return embeddings.map((e: any) => e.embedding);
320313
}
314+
315+
316+
async rerank(query: string, chunks: Chunk[]): Promise<number[]> {
317+
if (!query || !chunks.length) {
318+
throw new Error("Query and chunks must not be empty");
319+
}
320+
try {
321+
await this.updateWatsonxToken();
322+
323+
const headers = {
324+
"Content-Type": "application/json",
325+
Authorization: `${
326+
watsonxToken.expiration === -1 ? "ZenApiKey" : "Bearer"
327+
} ${watsonxToken.token}`,
328+
};
329+
330+
const payload: any = {
331+
inputs: chunks.map((chunk) => ({ text: chunk.content })),
332+
query: query,
333+
parameters: {
334+
truncate_input_tokens: 500,
335+
return_options: {
336+
top_n: chunks.length
337+
},
338+
},
339+
model_id: this.model,
340+
project_id: this.projectId,
341+
};
342+
343+
const resp = await this.fetch(
344+
new URL(
345+
`${this.apiBase}/ml/v1/text/rerank?version=${this.apiVersion}`,
346+
),
347+
{
348+
method: "POST",
349+
headers: headers,
350+
body: JSON.stringify(payload),
351+
},
352+
);
353+
354+
if (!resp.ok) {
355+
throw new Error(`Failed to rerank chunks: ${await resp.text()}`);
356+
}
357+
const data = await resp.json();
358+
const ranking = data.results;
359+
360+
if (!ranking) {
361+
throw new Error("Empty response received from Watsonx");
362+
}
363+
364+
// Sort results by index to maintain original order
365+
return ranking
366+
.sort((a: any, b: any) => a.index - b.index)
367+
.map((result: any) => result.score);
368+
369+
} catch (error) {
370+
console.error("Error in WatsonxReranker.rerank:", error);
371+
throw error;
372+
}
373+
}
321374
}
322375

323376
export default WatsonX;

docs/docs/customize/model-providers/more/watsonx.mdx

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,37 @@ To view the list of available embeddings models, visit [this page](https://datap
244244
```
245245
</TabItem>
246246
</Tabs>
247+
248+
249+
## Reranker
250+
251+
<Tabs groupId="config-example">
252+
<TabItem value="yaml" label="YAML">
253+
```yaml title="config.yaml"
254+
models:
255+
- name: Watsonx Reranker
256+
provider: watsonx
257+
model: cross-encoder/ms-marco-minilm-l-12-v2
258+
apiBase: https://us-south.ml.cloud.ibm.com
259+
projectId: PROJECT_ID
260+
apiKey: API_KEY/ZENAPI_KEY/USERNAME:PASSWORD
261+
apiVersion: 2024-03-14
262+
```
263+
</TabItem>
264+
<TabItem value="json" label="JSON">
265+
```json title="config.json"
266+
{
267+
"reranker": {
268+
"name": "watsonx",
269+
"params": {
270+
"model": "cross-encoder/ms-marco-minilm-l-12-v2",
271+
"apiBase": "watsonx endpoint e.g. https://us-south.ml.cloud.ibm.com",
272+
"projectId": "PROJECT_ID",
273+
"apiKey": "API_KEY/ZENAPI_KEY/USERNAME:PASSWORD",
274+
"apiVersion": "2024-03-14"
275+
}
276+
}
277+
}
278+
```
279+
</TabItem>
280+
</Tabs>

extensions/vscode/config_schema.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2881,6 +2881,7 @@
28812881
"bedrock",
28822882
"cohere",
28832883
"voyage",
2884+
"watsonx",
28842885
"llm",
28852886
"free-trial",
28862887
"huggingface-tei"

packages/config-types/src/index.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ export const modelDescriptionSchema = z.object({
5858
"continue-proxy",
5959
"nebius",
6060
"scaleway",
61+
"watsonx"
6162
]),
6263
model: z.string(),
6364
apiKey: z.string().optional(),
@@ -111,6 +112,7 @@ export const embeddingsProviderSchema = z.object({
111112
"continue-proxy",
112113
"nebius",
113114
"scaleway",
115+
"watsonx"
114116
]),
115117
apiBase: z.string().optional(),
116118
apiKey: z.string().optional(),
@@ -173,7 +175,7 @@ export const contextProviderSchema = z.object({
173175
export type ContextProvider = z.infer<typeof contextProviderSchema>;
174176

175177
export const rerankerSchema = z.object({
176-
name: z.enum(["cohere", "voyage", "llm", "continue-proxy"]),
178+
name: z.enum(["cohere", "voyage", "watsonx", "llm", "continue-proxy"]),
177179
params: z.record(z.any()).optional(),
178180
});
179181
export type Reranker = z.infer<typeof rerankerSchema>;

0 commit comments

Comments
 (0)