Skip to content

Nate/better retrieval #1677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions .continueignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
**/*.run.xml
archive/**/*
extensions/vscode/models/**/*
docs/docs/languages
\*_/_.run.xml
docs/docs/languages
.changes/
.idea/
.vscode/
.archive/
**/*.scm
3 changes: 2 additions & 1 deletion .prettierignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
extensions/vscode/continue_rc_schema.json
extensions/vscode/continue_rc_schema.json
**/.continueignore
6 changes: 5 additions & 1 deletion binary/test/binary.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ describe("Test Suite", () => {
);
}

const ide = new FileSystemIde();
const testDir = path.join(__dirname, "..", ".test");
if (!fs.existsSync(testDir)) {
fs.mkdirSync(testDir);
}
const ide = new FileSystemIde(testDir);
const reverseIde = new ReverseMessageIde(messenger.on.bind(messenger), ide);

// Wait for core to set itself up
Expand Down
4 changes: 2 additions & 2 deletions core/autocomplete/completionProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import OpenAI from "openai";
import path from "path";
import { v4 as uuidv4 } from "uuid";
import { RangeInFileWithContents } from "../commands/util.js";
import { ConfigHandler } from "../config/handler.js";
import { IConfigHandler } from "../config/IConfigHandler.js";
import { TRIAL_FIM_MODEL } from "../config/onboarding.js";
import { streamLines } from "../diff/util.js";
import {
Expand Down Expand Up @@ -145,7 +145,7 @@ export class CompletionProvider {
private static lastUUID: string | undefined = undefined;

constructor(
private readonly configHandler: ConfigHandler,
private readonly configHandler: IConfigHandler,
private readonly ide: IDE,
private readonly getLlm: () => Promise<ILLM | undefined>,
private readonly _onError: (e: any) => void,
Expand Down
3 changes: 2 additions & 1 deletion core/config/handler.ts → core/config/ConfigHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ import {
ILLM,
} from "../index.js";
import { Telemetry } from "../util/posthog.js";
import { IConfigHandler } from "./IConfigHandler.js";
import { finalToBrowserConfig, loadFullConfigNode } from "./load.js";

export class ConfigHandler {
export class ConfigHandler implements IConfigHandler {
private savedConfig: ContinueConfig | undefined;
private savedBrowserConfig?: BrowserSerializedContinueConfig;
private additionalContextProviders: IContextProvider[] = [];
Expand Down
17 changes: 17 additions & 0 deletions core/config/IConfigHandler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import {
BrowserSerializedContinueConfig,
ContinueConfig,
IContextProvider,
IdeSettings,
ILLM,
} from "../index.js";

export interface IConfigHandler {
updateIdeSettings(ideSettings: IdeSettings): void;
onConfigUpdate(listener: (newConfig: ContinueConfig) => void): void;
reloadConfig(): Promise<void>;
getSerializedConfig(): Promise<BrowserSerializedContinueConfig>;
loadConfig(): Promise<ContinueConfig>;
llmFromTitle(title?: string): Promise<ILLM>;
registerCustomContextProvider(contextProvider: IContextProvider): void;
}
1 change: 1 addition & 0 deletions core/context/retrieval/fullTextSearch.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { BranchAndDir, Chunk } from "../../index.js";
import { FullTextSearchCodebaseIndex } from "../../indexing/FullTextSearch.js";

export async function retrieveFts(
query: string,
n: number,
Expand Down
59 changes: 59 additions & 0 deletions core/context/retrieval/pipelines/BaseRetrievalPipeline.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import {
BranchAndDir,
Chunk,
EmbeddingsProvider,
IDE,
Reranker,
} from "../../..";
import { LanceDbIndex } from "../../../indexing/LanceDbIndex";
import { retrieveFts } from "../fullTextSearch";

export interface RetrievalPipelineOptions {
ide: IDE;
embeddingsProvider: EmbeddingsProvider;
reranker: Reranker | undefined;

input: string;
nRetrieve: number;
nFinal: number;
tags: BranchAndDir[];
filterDirectory?: string;
}

export interface IRetrievalPipeline {
run(options: RetrievalPipelineOptions): Promise<Chunk[]>;
}

export default class BaseRetrievalPipeline implements IRetrievalPipeline {
private lanceDbIndex: LanceDbIndex;
constructor(protected readonly options: RetrievalPipelineOptions) {
this.lanceDbIndex = new LanceDbIndex(options.embeddingsProvider, (path) =>
options.ide.readFile(path),
);
}

protected async retrieveFts(input: string, n: number): Promise<Chunk[]> {
return retrieveFts(
input,
n,
this.options.tags,
this.options.filterDirectory,
);
}

protected async retrieveEmbeddings(
input: string,
n: number,
): Promise<Chunk[]> {
return this.lanceDbIndex.retrieve(
input,
n,
this.options.tags,
this.options.filterDirectory,
);
}

run(): Promise<Chunk[]> {
throw new Error("Not implemented");
}
}
26 changes: 26 additions & 0 deletions core/context/retrieval/pipelines/NoRerankerRetrievalPipeline.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { Chunk } from "../../..";
import { deduplicateChunks } from "../util";
import BaseRetrievalPipeline from "./BaseRetrievalPipeline";

export default class NoRerankerRetrievalPipeline extends BaseRetrievalPipeline {
async run(): Promise<Chunk[]> {
const { input } = this.options;

// Get all retrieval results
const retrievalResults: Chunk[] = [];

// Full-text search
const ftsResults = await this.retrieveFts(input, this.options.nFinal / 2);
retrievalResults.push(...ftsResults);

// Embeddings
const embeddingResults = await this.retrieveEmbeddings(
input,
this.options.nFinal / 2,
);
retrievalResults.push(...embeddingResults);

const finalResults: Chunk[] = deduplicateChunks(retrievalResults);
return finalResults;
}
}
125 changes: 125 additions & 0 deletions core/context/retrieval/pipelines/RerankerRetrievalPipeline.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import { Chunk } from "../../..";
import { RETRIEVAL_PARAMS } from "../../../util/parameters";
import { deduplicateChunks } from "../util";
import BaseRetrievalPipeline from "./BaseRetrievalPipeline";

export default class RerankerRetrievalPipeline extends BaseRetrievalPipeline {
private async _retrieveInitial(): Promise<Chunk[]> {
const { input, nRetrieve } = this.options;

// Get all retrieval results
const retrievalResults: Chunk[] = [];

// Full-text search
const ftsResults = await this.retrieveFts(input, nRetrieve / 2);
retrievalResults.push(...ftsResults);

// Embeddings
const embeddingResults = await this.retrieveEmbeddings(input, nRetrieve);
retrievalResults.push(
...embeddingResults.slice(0, nRetrieve - ftsResults.length),
);

const results: Chunk[] = deduplicateChunks(retrievalResults);
return results;
}

private async _rerank(input: string, chunks: Chunk[]): Promise<Chunk[]> {
if (!this.options.reranker) {
throw new Error("No reranker provided");
}

let scores: number[] = await this.options.reranker.rerank(input, chunks);

// Filter out low-scoring results
let results = chunks;
// let results = chunks.filter(
// (_, i) => scores[i] >= RETRIEVAL_PARAMS.rerankThreshold,
Copy link
Contributor

@AnnoyingTechnology AnnoyingTechnology Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been commented out, is there a reason for that ?
I see the threshold is set to 0.3, which seems to high.
Setting it to 0.1 and de-commenting this section would at least remove chucks that are totally irrelevant.

Having the rerankThreshold configurable in the config.conf's reranker section would also be better than a hard-coded value.

// );
// scores = scores.filter(
// (score) => score >= RETRIEVAL_PARAMS.rerankThreshold,
// );

results.sort(
(a, b) => scores[results.indexOf(a)] - scores[results.indexOf(b)],
);
results = results.slice(-this.options.nFinal);
return results;
}

private async _expandWithEmbeddings(chunks: Chunk[]): Promise<Chunk[]> {
const topResults = chunks.slice(
-RETRIEVAL_PARAMS.nResultsToExpandWithEmbeddings,
);

const expanded = await Promise.all(
topResults.map(async (chunk, i) => {
const results = await this.retrieveEmbeddings(
chunk.content,
RETRIEVAL_PARAMS.nEmbeddingsExpandTo,
);
return results;
}),
);
return expanded.flat();
}

private async _expandRankedResults(chunks: Chunk[]): Promise<Chunk[]> {
let results: Chunk[] = [];

const embeddingsResults = await this._expandWithEmbeddings(chunks);
results.push(...embeddingsResults);

return results;
}

async run(): Promise<Chunk[]> {
// Retrieve initial results
let results = await this._retrieveInitial();

// Rerank
const { input } = this.options;
results = await this._rerank(input, results);

// // // Expand top reranked results
// const expanded = await this._expandRankedResults(results);
// results.push(...expanded);

// // De-duplicate
// results = deduplicateChunks(results);

// // Rerank again
// results = await this._rerank(input, results);

// TODO: stitch together results

return results;
}
}

// Source: expansion with code graph
// consider doing this after reranking? Or just having a lower reranking threshold
// This is VS Code only until we use PSI for JetBrains or build our own general solution
// TODO: Need to pass in the expandSnippet function as a function argument
// because this import causes `tsc` to fail
// if ((await extras.ide.getIdeInfo()).ideType === "vscode") {
// const { expandSnippet } = await import(
// "../../../extensions/vscode/src/util/expandSnippet"
// );
// let expansionResults = (
// await Promise.all(
// extras.selectedCode.map(async (rif) => {
// return expandSnippet(
// rif.filepath,
// rif.range.start.line,
// rif.range.end.line,
// extras.ide,
// );
// }),
// )
// ).flat() as Chunk[];
// retrievalResults.push(...expansionResults);
// }

// Source: Open file exact match
// Source: Class/function name exact match
Loading
Loading