Skip to content

Commit 1386eb6

Browse files
authored
Nate/better retrieval (#1677)
* deduplicatearray tests * break out separate retrieval pipelines * IConfigHandler * tests for codebase indexer * better .continueignore for continue * indexing fixes * ignore .gitignore and .continueignore when indexing * retrieval pipeline improvements
1 parent f4198a4 commit 1386eb6

39 files changed

+704
-216
lines changed

.continueignore

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
**/*.run.xml
2-
archive/**/*
3-
extensions/vscode/models/**/*
4-
docs/docs/languages
1+
\*_/_.run.xml
2+
docs/docs/languages
3+
.changes/
4+
.idea/
5+
.vscode/
6+
.archive/
7+
**/*.scm

.prettierignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
extensions/vscode/continue_rc_schema.json
1+
extensions/vscode/continue_rc_schema.json
2+
**/.continueignore

binary/test/binary.test.ts

+5-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,11 @@ describe("Test Suite", () => {
7373
);
7474
}
7575

76-
const ide = new FileSystemIde();
76+
const testDir = path.join(__dirname, "..", ".test");
77+
if (!fs.existsSync(testDir)) {
78+
fs.mkdirSync(testDir);
79+
}
80+
const ide = new FileSystemIde(testDir);
7781
const reverseIde = new ReverseMessageIde(messenger.on.bind(messenger), ide);
7882

7983
// Wait for core to set itself up

core/autocomplete/completionProvider.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import OpenAI from "openai";
44
import path from "path";
55
import { v4 as uuidv4 } from "uuid";
66
import { RangeInFileWithContents } from "../commands/util.js";
7-
import { ConfigHandler } from "../config/handler.js";
7+
import { IConfigHandler } from "../config/IConfigHandler.js";
88
import { TRIAL_FIM_MODEL } from "../config/onboarding.js";
99
import { streamLines } from "../diff/util.js";
1010
import {
@@ -145,7 +145,7 @@ export class CompletionProvider {
145145
private static lastUUID: string | undefined = undefined;
146146

147147
constructor(
148-
private readonly configHandler: ConfigHandler,
148+
private readonly configHandler: IConfigHandler,
149149
private readonly ide: IDE,
150150
private readonly getLlm: () => Promise<ILLM | undefined>,
151151
private readonly _onError: (e: any) => void,

core/config/handler.ts renamed to core/config/ConfigHandler.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ import {
88
ILLM,
99
} from "../index.js";
1010
import { Telemetry } from "../util/posthog.js";
11+
import { IConfigHandler } from "./IConfigHandler.js";
1112
import { finalToBrowserConfig, loadFullConfigNode } from "./load.js";
1213

13-
export class ConfigHandler {
14+
export class ConfigHandler implements IConfigHandler {
1415
private savedConfig: ContinueConfig | undefined;
1516
private savedBrowserConfig?: BrowserSerializedContinueConfig;
1617
private additionalContextProviders: IContextProvider[] = [];

core/config/IConfigHandler.ts

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import {
2+
BrowserSerializedContinueConfig,
3+
ContinueConfig,
4+
IContextProvider,
5+
IdeSettings,
6+
ILLM,
7+
} from "../index.js";
8+
9+
export interface IConfigHandler {
10+
updateIdeSettings(ideSettings: IdeSettings): void;
11+
onConfigUpdate(listener: (newConfig: ContinueConfig) => void): void;
12+
reloadConfig(): Promise<void>;
13+
getSerializedConfig(): Promise<BrowserSerializedContinueConfig>;
14+
loadConfig(): Promise<ContinueConfig>;
15+
llmFromTitle(title?: string): Promise<ILLM>;
16+
registerCustomContextProvider(contextProvider: IContextProvider): void;
17+
}

core/context/retrieval/fullTextSearch.ts

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { BranchAndDir, Chunk } from "../../index.js";
22
import { FullTextSearchCodebaseIndex } from "../../indexing/FullTextSearch.js";
3+
34
export async function retrieveFts(
45
query: string,
56
n: number,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import {
2+
BranchAndDir,
3+
Chunk,
4+
EmbeddingsProvider,
5+
IDE,
6+
Reranker,
7+
} from "../../..";
8+
import { LanceDbIndex } from "../../../indexing/LanceDbIndex";
9+
import { retrieveFts } from "../fullTextSearch";
10+
11+
export interface RetrievalPipelineOptions {
12+
ide: IDE;
13+
embeddingsProvider: EmbeddingsProvider;
14+
reranker: Reranker | undefined;
15+
16+
input: string;
17+
nRetrieve: number;
18+
nFinal: number;
19+
tags: BranchAndDir[];
20+
filterDirectory?: string;
21+
}
22+
23+
export interface IRetrievalPipeline {
24+
run(options: RetrievalPipelineOptions): Promise<Chunk[]>;
25+
}
26+
27+
export default class BaseRetrievalPipeline implements IRetrievalPipeline {
28+
private lanceDbIndex: LanceDbIndex;
29+
constructor(protected readonly options: RetrievalPipelineOptions) {
30+
this.lanceDbIndex = new LanceDbIndex(options.embeddingsProvider, (path) =>
31+
options.ide.readFile(path),
32+
);
33+
}
34+
35+
protected async retrieveFts(input: string, n: number): Promise<Chunk[]> {
36+
return retrieveFts(
37+
input,
38+
n,
39+
this.options.tags,
40+
this.options.filterDirectory,
41+
);
42+
}
43+
44+
protected async retrieveEmbeddings(
45+
input: string,
46+
n: number,
47+
): Promise<Chunk[]> {
48+
return this.lanceDbIndex.retrieve(
49+
input,
50+
n,
51+
this.options.tags,
52+
this.options.filterDirectory,
53+
);
54+
}
55+
56+
run(): Promise<Chunk[]> {
57+
throw new Error("Not implemented");
58+
}
59+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import { Chunk } from "../../..";
2+
import { deduplicateChunks } from "../util";
3+
import BaseRetrievalPipeline from "./BaseRetrievalPipeline";
4+
5+
export default class NoRerankerRetrievalPipeline extends BaseRetrievalPipeline {
6+
async run(): Promise<Chunk[]> {
7+
const { input } = this.options;
8+
9+
// Get all retrieval results
10+
const retrievalResults: Chunk[] = [];
11+
12+
// Full-text search
13+
const ftsResults = await this.retrieveFts(input, this.options.nFinal / 2);
14+
retrievalResults.push(...ftsResults);
15+
16+
// Embeddings
17+
const embeddingResults = await this.retrieveEmbeddings(
18+
input,
19+
this.options.nFinal / 2,
20+
);
21+
retrievalResults.push(...embeddingResults);
22+
23+
const finalResults: Chunk[] = deduplicateChunks(retrievalResults);
24+
return finalResults;
25+
}
26+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import { Chunk } from "../../..";
2+
import { RETRIEVAL_PARAMS } from "../../../util/parameters";
3+
import { deduplicateChunks } from "../util";
4+
import BaseRetrievalPipeline from "./BaseRetrievalPipeline";
5+
6+
export default class RerankerRetrievalPipeline extends BaseRetrievalPipeline {
7+
private async _retrieveInitial(): Promise<Chunk[]> {
8+
const { input, nRetrieve } = this.options;
9+
10+
// Get all retrieval results
11+
const retrievalResults: Chunk[] = [];
12+
13+
// Full-text search
14+
const ftsResults = await this.retrieveFts(input, nRetrieve / 2);
15+
retrievalResults.push(...ftsResults);
16+
17+
// Embeddings
18+
const embeddingResults = await this.retrieveEmbeddings(input, nRetrieve);
19+
retrievalResults.push(
20+
...embeddingResults.slice(0, nRetrieve - ftsResults.length),
21+
);
22+
23+
const results: Chunk[] = deduplicateChunks(retrievalResults);
24+
return results;
25+
}
26+
27+
private async _rerank(input: string, chunks: Chunk[]): Promise<Chunk[]> {
28+
if (!this.options.reranker) {
29+
throw new Error("No reranker provided");
30+
}
31+
32+
let scores: number[] = await this.options.reranker.rerank(input, chunks);
33+
34+
// Filter out low-scoring results
35+
let results = chunks;
36+
// let results = chunks.filter(
37+
// (_, i) => scores[i] >= RETRIEVAL_PARAMS.rerankThreshold,
38+
// );
39+
// scores = scores.filter(
40+
// (score) => score >= RETRIEVAL_PARAMS.rerankThreshold,
41+
// );
42+
43+
results.sort(
44+
(a, b) => scores[results.indexOf(a)] - scores[results.indexOf(b)],
45+
);
46+
results = results.slice(-this.options.nFinal);
47+
return results;
48+
}
49+
50+
private async _expandWithEmbeddings(chunks: Chunk[]): Promise<Chunk[]> {
51+
const topResults = chunks.slice(
52+
-RETRIEVAL_PARAMS.nResultsToExpandWithEmbeddings,
53+
);
54+
55+
const expanded = await Promise.all(
56+
topResults.map(async (chunk, i) => {
57+
const results = await this.retrieveEmbeddings(
58+
chunk.content,
59+
RETRIEVAL_PARAMS.nEmbeddingsExpandTo,
60+
);
61+
return results;
62+
}),
63+
);
64+
return expanded.flat();
65+
}
66+
67+
private async _expandRankedResults(chunks: Chunk[]): Promise<Chunk[]> {
68+
let results: Chunk[] = [];
69+
70+
const embeddingsResults = await this._expandWithEmbeddings(chunks);
71+
results.push(...embeddingsResults);
72+
73+
return results;
74+
}
75+
76+
async run(): Promise<Chunk[]> {
77+
// Retrieve initial results
78+
let results = await this._retrieveInitial();
79+
80+
// Rerank
81+
const { input } = this.options;
82+
results = await this._rerank(input, results);
83+
84+
// // // Expand top reranked results
85+
// const expanded = await this._expandRankedResults(results);
86+
// results.push(...expanded);
87+
88+
// // De-duplicate
89+
// results = deduplicateChunks(results);
90+
91+
// // Rerank again
92+
// results = await this._rerank(input, results);
93+
94+
// TODO: stitch together results
95+
96+
return results;
97+
}
98+
}
99+
100+
// Source: expansion with code graph
101+
// consider doing this after reranking? Or just having a lower reranking threshold
102+
// This is VS Code only until we use PSI for JetBrains or build our own general solution
103+
// TODO: Need to pass in the expandSnippet function as a function argument
104+
// because this import causes `tsc` to fail
105+
// if ((await extras.ide.getIdeInfo()).ideType === "vscode") {
106+
// const { expandSnippet } = await import(
107+
// "../../../extensions/vscode/src/util/expandSnippet"
108+
// );
109+
// let expansionResults = (
110+
// await Promise.all(
111+
// extras.selectedCode.map(async (rif) => {
112+
// return expandSnippet(
113+
// rif.filepath,
114+
// rif.range.start.line,
115+
// rif.range.end.line,
116+
// extras.ide,
117+
// );
118+
// }),
119+
// )
120+
// ).flat() as Chunk[];
121+
// retrievalResults.push(...expansionResults);
122+
// }
123+
124+
// Source: Open file exact match
125+
// Source: Class/function name exact match

0 commit comments

Comments
 (0)