
Commit 11fdace

Merge branch 'main' into feat/watsonx_reranker_integration
2 parents 123971f + 58d62fa

23 files changed, +335 -129 lines

.github/workflows/pr_checks.yaml

Lines changed: 7 additions & 1 deletion

@@ -303,7 +303,13 @@ jobs:
       cd extensions/vscode
       npm ci
       npm run e2e:compile
-      FILES=$(ls -1 e2e/_output/tests/*.test.js | jq -R . | jq -s .)
+      if [[ "${{ github.event.pull_request.head.repo.fork }}" == "true" || "${{ github.actor }}" == "dependabot[bot]" ]]; then
+        # Exclude SSH tests for forks
+        FILES=$(ls -1 e2e/_output/tests/*.test.js | grep -v "SSH" | jq -R . | jq -s .)
+      else
+        # Include all tests for non-forks
+        FILES=$(ls -1 e2e/_output/tests/*.test.js | jq -R . | jq -s .)
+      fi
       echo "test_file_matrix<<EOF" >> $GITHUB_OUTPUT
       echo "$FILES" >> $GITHUB_OUTPUT
       echo "EOF" >> $GITHUB_OUTPUT

core/config/load.ts

Lines changed: 1 addition & 1 deletion

@@ -272,7 +272,7 @@ async function intermediateToFinalConfig(
       {
         ...desc,
         model: modelName,
-        title: `${llm.title} - ${modelName}`,
+        title: modelName,
       },
       ide.readFile.bind(ide),
       uniqueId,

core/config/yaml/models.ts

Lines changed: 1 addition & 1 deletion

@@ -127,7 +127,7 @@ async function autodetectModels(
       {
         ...model,
         model: modelName,
-        name: `${llm.title} - ${modelName}`,
+        name: modelName,
       },
       uniqueId,
       ideSettings,

core/index.d.ts

Lines changed: 5 additions & 1 deletion

@@ -163,6 +163,10 @@ export interface ILLM extends LLMOptions {
   ): string | ChatMessage[];
 }
 
+export interface ModelInstaller {
+  installModel(modelName: string, signal: AbortSignal, progressReporter?: (task: string, increment: number, total: number) => void): Promise<any>;
+}
+
 export type ContextProviderType = "normal" | "query" | "submenu";
 
 export interface ContextProviderDescription {

@@ -543,7 +547,7 @@ export interface CustomLLMWithOptionals {
     signal: AbortSignal,
     options: CompletionOptions,
     fetch: (input: RequestInfo | URL, init?: RequestInit) => Promise<Response>,
-  ) => AsyncGenerator<string>;
+  ) => AsyncGenerator<ChatMessage | string>;
   listModels?: (
     fetch: (input: RequestInfo | URL, init?: RequestInit) => Promise<Response>,
   ) => Promise<string[]>;
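
To make the new contract concrete, here is a minimal sketch of a conforming provider. The interface is the one added above; the class is hypothetical and exists only to illustrate the expected callback and cancellation handling:

interface ModelInstaller {
  installModel(modelName: string, signal: AbortSignal, progressReporter?: (task: string, increment: number, total: number) => void): Promise<any>;
}

// Hypothetical provider, for illustration only
class DummyInstaller implements ModelInstaller {
  async installModel(
    modelName: string,
    signal: AbortSignal,
    progressReporter?: (task: string, increment: number, total: number) => void,
  ): Promise<void> {
    for (let done = 0; done <= 100; done += 25) {
      // Installers are expected to honor the AbortSignal between steps
      if (signal.aborted) {
        throw new Error(`Installation of '${modelName}' aborted`);
      }
      progressReporter?.(`pulling ${modelName}`, done, 100);
    }
  }
}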

core/llm/index.ts

Lines changed: 30 additions & 2 deletions

@@ -19,6 +19,7 @@ import {
   LLMFullCompletionOptions,
   LLMOptions,
   ModelCapability,
+  ModelInstaller,
   PromptLog,
   PromptTemplate,
   RequestOptions,

@@ -58,6 +59,17 @@ import {
   toFimBody,
 } from "./openaiTypeConverters.js";
 
+export class LLMError extends Error {
+  constructor(message: string, public llm: ILLM) {
+    super(message);
+  }
+}
+
+export function isModelInstaller(provider: any): provider is ModelInstaller {
+  return provider && typeof provider.installModel === "function";
+}
+
 export abstract class BaseLLM implements ILLM {
   static providerName: string;
   static defaultOptions: Partial<LLMOptions> | undefined = undefined;

@@ -380,9 +392,11 @@ export abstract class BaseLLM implements ILLM {
       if (!resp.ok) {
         let text = await resp.text();
         if (resp.status === 404 && !resp.url.includes("/v1")) {
-          if (text.includes("try pulling it first")) {
-            const model = JSON.parse(text).error.split(" ")[1].slice(1, -1);
+          const error = JSON.parse(text)?.error?.replace(/"/g, "'");
+          let model = error?.match(/model '(.*)' not found/)?.[1];
+          if (model && resp.url.match("127.0.0.1:11434")) {
             text = `The model "${model}" was not found. To download it, run \`ollama run ${model}\`.`;
+            throw new LLMError(text, this); // No need to add HTTP status details
           } else if (text.includes("/api/chat")) {
             text =
               "The /api/chat endpoint was not found. This may mean that you are using an older version of Ollama that does not support /api/chat. Upgrading to the latest version will solve the issue.";

@@ -442,6 +456,10 @@
           throw new Error(message);
         }
       }
+      // If e is an instance of LLMError, rethrow it
+      if (e instanceof LLMError) {
+        throw e;
+      }
       throw new Error(e.message);
     }
   };

@@ -763,6 +781,7 @@
     }
 
     let completion = "";
+    let citations: null | string[] = null;
 
     try {
       if (this.templateMessages) {

@@ -790,6 +809,8 @@
           completion = renderChatMessage(msg);
         } else {
           // Stream true
+          console.log("Streaming");
+
           const stream = this.openaiAdapter.chatCompletionStream(
             {
               ...body,

@@ -802,6 +823,9 @@
             if (result) {
               yield result;
             }
+            if (!citations && (chunk as any).citations && Array.isArray((chunk as any).citations)) {
+              citations = (chunk as any).citations;
+            }
           }
         }
       } else {

@@ -824,6 +848,10 @@
 
     if (logEnabled && this.writeLog) {
       await this.writeLog(`Completion:\n${completion}\n\n`);
+
+      if (citations) {
+        await this.writeLog(`Citations:\n${citations.map((c, i) => `${i + 1}: ${c}`).join("\n")}\n\n`);
+      }
     }
 
     return {
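
Two notable additions here: LLMError carries the originating llm, so upstream handlers can show provider-specific guidance (as the Ollama 404 path now does), and isModelInstaller is a runtime type guard for the new ModelInstaller interface. A hypothetical call site (the provider variable and the import path are assumptions):

import { isModelInstaller } from "core/llm"; // import path assumed for illustration

async function maybeInstall(provider: unknown, modelName: string): Promise<void> {
  if (isModelInstaller(provider)) {
    // Narrowed to ModelInstaller here, so installModel type-checks
    await provider.installModel(modelName, new AbortController().signal);
  }
}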

core/llm/llms/CustomLLM.ts

Lines changed: 12 additions & 3 deletions

@@ -1,4 +1,5 @@
 import { ChatMessage, CompletionOptions, CustomLLM } from "../../index.js";
+import { renderChatMessage } from "../../util/messageContent.js";
 import { BaseLLM } from "../index.js";
 
 class CustomLLMClass extends BaseLLM {

@@ -18,7 +19,7 @@ class CustomLLMClass extends BaseLLM {
     signal: AbortSignal,
     options: CompletionOptions,
     fetch: (input: RequestInfo | URL, init?: RequestInit) => Promise<Response>,
-  ) => AsyncGenerator<string>;
+  ) => AsyncGenerator<ChatMessage | string>;
 
   constructor(custom: CustomLLM) {
     super(custom.options || { model: "custom" });

@@ -38,7 +39,11 @@ class CustomLLMClass extends BaseLLM {
         options,
         (...args) => this.fetch(...args),
       )) {
-        yield { role: "assistant", content };
+        if (typeof content === "string") {
+          yield { role: "assistant", content };
+        } else {
+          yield content;
+        }
       }
     } else {
       for await (const update of super._streamChat(messages, signal, options)) {

@@ -68,7 +73,11 @@ class CustomLLMClass extends BaseLLM {
         options,
         (...args) => this.fetch(...args),
       )) {
-        yield content;
+        if (typeof content === "string") {
+          yield content;
+        } else {
+          yield renderChatMessage(content);
+        }
       }
     } else {
       throw new Error(
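
With the generator type widened to AsyncGenerator<ChatMessage | string>, a user-supplied streamCompletion may interleave plain text chunks with structured messages. A self-contained sketch (the demo generator is hypothetical):

import { ChatMessage } from "../../index.js";

// Both yielded shapes are now legal for a custom stream
async function* demoStream(): AsyncGenerator<ChatMessage | string> {
  yield "partial text "; // plain string chunk, handled as before
  yield { role: "assistant", content: "final text" }; // structured chunk
}

async function consume(): Promise<void> {
  for await (const chunk of demoStream()) {
    // Branch on the runtime shape, mirroring CustomLLMClass above
    if (typeof chunk === "string") {
      console.log("string:", chunk);
    } else {
      console.log("message:", chunk.content);
    }
  }
}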

core/llm/llms/Ollama.ts

Lines changed: 35 additions & 2 deletions

@@ -1,9 +1,10 @@
 import { JSONSchema7, JSONSchema7Object } from "json-schema";
 
-import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js";
+import { ChatMessage, CompletionOptions, ModelInstaller, LLMOptions } from "../../index.js";
 import { renderChatMessage } from "../../util/messageContent.js";
 import { BaseLLM } from "../index.js";
 import { streamResponse } from "../stream.js";
+import { getRemoteModelInfo } from "../../util/ollamaHelper.js";
 
 type OllamaChatMessage = {
   role: "tool" | "user" | "assistant" | "system";

@@ -123,7 +124,7 @@ interface OllamaTool {
   };
 }
 
-class Ollama extends BaseLLM {
+class Ollama extends BaseLLM implements ModelInstaller {
   static providerName = "ollama";
   static defaultOptions: Partial<LLMOptions> = {
     apiBase: "http://localhost:11434/",

@@ -574,6 +575,38 @@ class Ollama extends BaseLLM implements ModelInstaller {
     }
     return embedding;
   }
+
+  public async installModel(
+    modelName: string,
+    signal: AbortSignal,
+    progressReporter?: (task: string, increment: number, total: number) => void,
+  ): Promise<any> {
+    const modelInfo = await getRemoteModelInfo(modelName, signal);
+    if (!modelInfo) {
+      throw new Error(`'${modelName}' not found in the Ollama registry!`);
+    }
+    const response = await fetch(this.getEndpoint("api/pull"), {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+        Authorization: `Bearer ${this.apiKey}`,
+      },
+      body: JSON.stringify({ name: modelName }),
+      signal,
+    });
+
+    const reader = response.body?.getReader();
+    // TODO: generate proper progress based on modelInfo size
+    while (true) {
+      const { done, value } = (await reader?.read()) || { done: true, value: undefined };
+      if (done) {
+        break;
+      }
+
+      const chunk = new TextDecoder().decode(value);
+      const lines = chunk.split("\n").filter(Boolean);
+      for (const line of lines) {
+        const data = JSON.parse(line);
+        progressReporter?.(data.status, data.completed, data.total);
+      }
+    }
+  }
 }
 
 export default Ollama;
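
A sketch of driving installModel with a progress reporter (the constructor options are assumed; completed and total come straight from Ollama's newline-delimited pull events, which the loop above forwards verbatim):

const ollama = new Ollama({ model: "llama3.1" }); // options shape assumed
const controller = new AbortController();

await ollama.installModel("llama3.1", controller.signal, (task, completed, total) => {
  // Some pull events carry only a status string, so guard the math
  const pct = completed && total ? Math.round((100 * completed) / total) : 0;
  console.log(total ? `${task}: ${pct}%` : task);
});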

core/llm/toolSupport.ts

Lines changed: 12 additions & 2 deletions

@@ -57,9 +57,18 @@ export const PROVIDER_TOOL_SUPPORT: Record<
   },
   // https://ollama.com/search?c=tools
   ollama: (model) => {
+    let modelName = "";
+    // Extract the model name after the last slash to support other registries
+    if (model.includes("/")) {
+      const parts = model.split("/");
+      modelName = parts[parts.length - 1];
+    } else {
+      modelName = model;
+    }
+
     if (
       ["vision", "math", "guard", "mistrallite", "mistral-openorca"].some(
-        (part) => model.toLowerCase().includes(part),
+        (part) => modelName.toLowerCase().includes(part),
       )
     ) {
       return false;

@@ -79,10 +88,11 @@ export const PROVIDER_TOOL_SUPPORT: Record<
       "nemotron",
       "llama3-groq",
       "granite3",
+      "granite-3",
       "aya-expanse",
       "firefunction-v2",
       "mistral",
-    ].some((part) => model.toLowerCase().startsWith(part))
+    ].some((part) => modelName.toLowerCase().includes(part))
     ) {
       return true;
     }
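
Because everything before the last slash is stripped, registry-qualified model IDs now hit the same allow/deny substring checks as bare names. Two illustrative calls (the registry host is hypothetical):

// "granite-3.1-8b-instruct" contains "granite-3" -> tools supported
PROVIDER_TOOL_SUPPORT.ollama("registry.example.com/library/granite-3.1-8b-instruct");

// "llama3.2-vision" contains "vision" -> tools not supported
PROVIDER_TOOL_SUPPORT.ollama("llama3.2-vision");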

core/util/ollamaHelper.ts

Lines changed: 59 additions & 1 deletion

@@ -1,6 +1,13 @@
+import crypto from "crypto";
+import { exec } from "node:child_process";
 import path from "node:path";
 import { IDE } from "..";
-import { exec } from "node:child_process";
+
+export interface ModelInfo {
+  id: string;
+  size: number;
+  digest: string;
+}
 
 export async function isOllamaInstalled(): Promise<boolean> {
   return new Promise((resolve, _reject) => {

@@ -39,3 +46,54 @@ export async function startLocalOllama(ide: IDE): Promise<any> {
     });
   }
 }
+
+export async function getRemoteModelInfo(
+  modelId: string,
+  signal?: AbortSignal,
+): Promise<ModelInfo | undefined> {
+  const start = Date.now();
+  const [modelName, tag = "latest"] = modelId.split(":");
+  const url = `https://registry.ollama.ai/v2/library/${modelName}/manifests/${tag}`;
+  try {
+    const sig = signal ? signal : AbortSignal.timeout(3000);
+    const response = await fetch(url, { signal: sig });
+
+    if (!response.ok) {
+      throw new Error(`Failed to fetch the model page: ${response.statusText}`);
+    }
+
+    // First, read the response body as an ArrayBuffer to compute the digest
+    const buffer = await response.arrayBuffer();
+    const digest = getDigest(buffer);
+
+    // Then, decode the ArrayBuffer into a string and parse it as JSON
+    const text = new TextDecoder().decode(buffer);
+    const manifest = JSON.parse(text) as {
+      config: { size: number };
+      layers: { size: number }[];
+    };
+    const modelSize =
+      manifest.config.size +
+      manifest.layers.reduce((sum, layer) => sum + layer.size, 0);
+
+    const data: ModelInfo = {
+      id: modelId,
+      size: modelSize,
+      digest,
+    };
+    return data;
+  } catch (error) {
+    console.error(`Error fetching or parsing model info: ${error}`);
+  } finally {
+    const elapsed = Date.now() - start;
+    console.log(`Fetched remote information for ${modelId} in ${elapsed} ms`);
+  }
+  return undefined;
+}
+
+function getDigest(buffer: ArrayBuffer): string {
+  const hash = crypto.createHash("sha256");
+  hash.update(new Uint8Array(buffer));
+  return hash.digest("hex");
+}
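
A quick usage sketch (the model tag is illustrative; size is the sum of the manifest's config and layer sizes, in bytes):

const info = await getRemoteModelInfo("llama3.1:8b", AbortSignal.timeout(5000));
if (info) {
  console.log(`${info.id}: ${(info.size / 1e9).toFixed(1)} GB, manifest sha256 ${info.digest}`);
} else {
  console.log("Model not found in the registry (or the request failed)");
}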

core/util/text.ts

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+export const capitalizeFirstLetter = (val: string) => {
+  if (val.length === 0) {
+    return "";
+  }
+  return val[0].toUpperCase() + val.slice(1);
+};
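
Behavior for the edge cases is what you'd expect:

capitalizeFirstLetter("ollama"); // "Ollama"
capitalizeFirstLetter(""); // ""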
