Skip to content

Commit 294f600

Browse files
jeasonnowjeasonnowjacoblee93
authored
langchain[minor]: add EnsembleRetriever (#5556)
* langchain[patch]: add support to merge retrievers * Format * Parallelize, lint, format, small fixes * Add entrypoint * Fix import * Add docs and fix build artifacts --------- Co-authored-by: jeasonnow <[email protected]> Co-authored-by: jacoblee93 <[email protected]>
1 parent d35d12d commit 294f600

File tree

17 files changed

+349
-2
lines changed

17 files changed

+349
-2
lines changed

docs/core_docs/docs/concepts.mdx

+3-2
Original file line numberDiff line numberDiff line change
@@ -672,8 +672,9 @@ LangChain provides several advanced retrieval types. A full list is below, along
672672
| [Multi Vector](/docs/how_to/multi_vector/) | Vectorstore + Document Store | Sometimes during indexing | If you are able to extract information from documents that you think is more relevant to index than the text itself. | This involves creating multiple vectors for each document. Each vector could be created in a myriad of ways - examples include summaries of the text and hypothetical questions. |
673673
| [Self Query](/docs/how_to/self_query/) | Vectorstore | Yes | If users are asking questions that are better answered by fetching documents based on metadata rather than similarity with the text. | This uses an LLM to transform user input into two things: (1) a string to look up semantically, (2) a metadata filer to go along with it. This is useful because oftentimes questions are about the METADATA of documents (not the content itself). |
674674
| [Contextual Compression](/docs/how_to/contextual_compression/) | Any | Sometimes | If you are finding that your retrieved documents contain too much irrelevant information and are distracting the LLM. | This puts a post-processing step on top of another retriever and extracts only the most relevant information from retrieved documents. This can be done with embeddings or an LLM. |
675-
| [Time-Weighted Vectorstore](/docs/how_to/time_weighted_vectorstore/) | Vectorstore | No | If you have timestamps associated with your documents, and you want to retrieve the most recent ones | This fetches documents based on a combination of semantic similarity (as in normal vector retrieval) and recency (looking at timestamps of indexed documents) |
676-
| [Multi-Query Retriever](/docs/how_to/multiple_queries/) | Any | Yes | If users are asking questions that are complex and require multiple pieces of distinct information to respond | This uses an LLM to generate multiple queries from the original one. This is useful when the original query needs pieces of information about multiple topics to be properly answered. By generating multiple queries, we can then fetch documents for each of them. |
675+
| [Time-Weighted Vectorstore](/docs/how_to/time_weighted_vectorstore/) | Vectorstore | No | If you have timestamps associated with your documents, and you want to retrieve the most recent ones. | This fetches documents based on a combination of semantic similarity (as in normal vector retrieval) and recency (looking at timestamps of indexed documents) |
676+
| [Multi-Query Retriever](/docs/how_to/multiple_queries/) | Any | Yes | If users are asking questions that are complex and require multiple pieces of distinct information to respond. | This uses an LLM to generate multiple queries from the original one. This is useful when the original query needs pieces of information about multiple topics to be properly answered. By generating multiple queries, we can then fetch documents for each of them. |
677+
| [Ensemble](/docs/how_to/ensemble_retriever) | Any | No | If you have multiple retrieval methods and want to try combining them. | This fetches documents from multiple retrievers and then combines them. |
677678

678679
### Text splitting
679680

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# How to combine results from multiple retrievers
2+
3+
:::info Prerequisites
4+
5+
This guide assumes familiarity with the following concepts:
6+
7+
- [Documents](/docs/concepts#document)
8+
- [Retrievers](/docs/concepts#retrievers)
9+
10+
:::
11+
12+
The [EnsembleRetriever](https://api.js.langchain.com/classes/langchain_retrievers_ensemble.EnsembleRetriever.html) supports ensembling of results from multiple retrievers. It is initialized with a list of [BaseRetriever](https://api.js.langchain.com/classes/langchain_core_retrievers.BaseRetriever.html) objects. EnsembleRetrievers rerank the results of the constituent retrievers based on the [Reciprocal Rank Fusion](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) algorithm.
13+
14+
By leveraging the strengths of different algorithms, the `EnsembleRetriever` can achieve better performance than any single algorithm.
15+
16+
One useful pattern is to combine a keyword matching retriever with a dense retriever (like embedding similarity), because their strengths are complementary. This can be considered a form of "hybrid search". The sparse retriever is good at finding relevant documents based on keywords, while the dense retriever is good at finding relevant documents based on semantic similarity.
17+
18+
Below we demonstrate ensembling of a [simple custom retriever](/docs/how_to/custom_retriever/) that simply returns documents that directly contain the input query with a retriever derived from a [demo, in-memory, vector store](https://api.js.langchain.com/classes/langchain_vectorstores_memory.MemoryVectorStore.html).
19+
20+
import CodeBlock from "@theme/CodeBlock";
21+
import Example from "@examples/retrievers/ensemble_retriever.ts";
22+
23+
<CodeBlock language="typescript">{Example}</CodeBlock>
24+
25+
## Next steps
26+
27+
You've now learned how to combine results from multiple retrievers.
28+
Next, check out some other retrieval how-to guides, such as how to [improve results using multiple embeddings per document](/docs/how_to/multi_vector)
29+
or how to [create your own custom retriever](/docs/how_to/custom_retriever).

docs/core_docs/docs/how_to/index.mdx

+1
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ Retrievers are responsible for taking a query and returning relevant documents.
133133
- [How to: generate multiple queries to retrieve data for](/docs/how_to/multiple_queries)
134134
- [How to: use contextual compression to compress the data retrieved](/docs/how_to/contextual_compression)
135135
- [How to: write a custom retriever class](/docs/how_to/custom_retriever)
136+
- [How to: combine the results from multiple retrievers](/docs/how_to/ensemble_retriever)
136137
- [How to: generate multiple embeddings per document](/docs/how_to/multi_vector)
137138
- [How to: retrieve the whole document for a chunk](/docs/how_to/parent_document_retriever)
138139
- [How to: generate metadata filters](/docs/how_to/self_query)

environment_tests/test-exports-bun/src/entrypoints.js

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ export * from "langchain/callbacks";
3737
export * from "langchain/output_parsers";
3838
export * from "langchain/retrievers/contextual_compression";
3939
export * from "langchain/retrievers/document_compressors";
40+
export * from "langchain/retrievers/ensemble";
4041
export * from "langchain/retrievers/multi_query";
4142
export * from "langchain/retrievers/multi_vector";
4243
export * from "langchain/retrievers/parent_document";

environment_tests/test-exports-cf/src/entrypoints.js

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ export * from "langchain/callbacks";
3737
export * from "langchain/output_parsers";
3838
export * from "langchain/retrievers/contextual_compression";
3939
export * from "langchain/retrievers/document_compressors";
40+
export * from "langchain/retrievers/ensemble";
4041
export * from "langchain/retrievers/multi_query";
4142
export * from "langchain/retrievers/multi_vector";
4243
export * from "langchain/retrievers/parent_document";

environment_tests/test-exports-cjs/src/entrypoints.js

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ const callbacks = require("langchain/callbacks");
3737
const output_parsers = require("langchain/output_parsers");
3838
const retrievers_contextual_compression = require("langchain/retrievers/contextual_compression");
3939
const retrievers_document_compressors = require("langchain/retrievers/document_compressors");
40+
const retrievers_ensemble = require("langchain/retrievers/ensemble");
4041
const retrievers_multi_query = require("langchain/retrievers/multi_query");
4142
const retrievers_multi_vector = require("langchain/retrievers/multi_vector");
4243
const retrievers_parent_document = require("langchain/retrievers/parent_document");

environment_tests/test-exports-esbuild/src/entrypoints.js

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import * as callbacks from "langchain/callbacks";
3737
import * as output_parsers from "langchain/output_parsers";
3838
import * as retrievers_contextual_compression from "langchain/retrievers/contextual_compression";
3939
import * as retrievers_document_compressors from "langchain/retrievers/document_compressors";
40+
import * as retrievers_ensemble from "langchain/retrievers/ensemble";
4041
import * as retrievers_multi_query from "langchain/retrievers/multi_query";
4142
import * as retrievers_multi_vector from "langchain/retrievers/multi_vector";
4243
import * as retrievers_parent_document from "langchain/retrievers/parent_document";

environment_tests/test-exports-esm/src/entrypoints.js

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ import * as callbacks from "langchain/callbacks";
3737
import * as output_parsers from "langchain/output_parsers";
3838
import * as retrievers_contextual_compression from "langchain/retrievers/contextual_compression";
3939
import * as retrievers_document_compressors from "langchain/retrievers/document_compressors";
40+
import * as retrievers_ensemble from "langchain/retrievers/ensemble";
4041
import * as retrievers_multi_query from "langchain/retrievers/multi_query";
4142
import * as retrievers_multi_vector from "langchain/retrievers/multi_vector";
4243
import * as retrievers_parent_document from "langchain/retrievers/parent_document";

environment_tests/test-exports-vercel/src/entrypoints.js

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ export * from "langchain/callbacks";
3737
export * from "langchain/output_parsers";
3838
export * from "langchain/retrievers/contextual_compression";
3939
export * from "langchain/retrievers/document_compressors";
40+
export * from "langchain/retrievers/ensemble";
4041
export * from "langchain/retrievers/multi_query";
4142
export * from "langchain/retrievers/multi_vector";
4243
export * from "langchain/retrievers/parent_document";

environment_tests/test-exports-vite/src/entrypoints.js

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ export * from "langchain/callbacks";
3737
export * from "langchain/output_parsers";
3838
export * from "langchain/retrievers/contextual_compression";
3939
export * from "langchain/retrievers/document_compressors";
40+
export * from "langchain/retrievers/ensemble";
4041
export * from "langchain/retrievers/multi_query";
4142
export * from "langchain/retrievers/multi_vector";
4243
export * from "langchain/retrievers/parent_document";
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import { EnsembleRetriever } from "langchain/retrievers/ensemble";
2+
import { MemoryVectorStore } from "langchain/vectorstores/memory";
3+
import { OpenAIEmbeddings } from "@langchain/openai";
4+
import { BaseRetriever, BaseRetrieverInput } from "@langchain/core/retrievers";
5+
import { Document } from "@langchain/core/documents";
6+
7+
class SimpleCustomRetriever extends BaseRetriever {
8+
lc_namespace = [];
9+
10+
documents: Document[];
11+
12+
constructor(fields: { documents: Document[] } & BaseRetrieverInput) {
13+
super(fields);
14+
this.documents = fields.documents;
15+
}
16+
17+
async _getRelevantDocuments(query: string): Promise<Document[]> {
18+
return this.documents.filter((document) =>
19+
document.pageContent.includes(query)
20+
);
21+
}
22+
}
23+
24+
const docs1 = [
25+
new Document({ pageContent: "I like apples", metadata: { source: 1 } }),
26+
new Document({ pageContent: "I like oranges", metadata: { source: 1 } }),
27+
new Document({
28+
pageContent: "apples and oranges are fruits",
29+
metadata: { source: 1 },
30+
}),
31+
];
32+
33+
const keywordRetriever = new SimpleCustomRetriever({ documents: docs1 });
34+
35+
const docs2 = [
36+
new Document({ pageContent: "You like apples", metadata: { source: 2 } }),
37+
new Document({ pageContent: "You like oranges", metadata: { source: 2 } }),
38+
];
39+
40+
const vectorstore = await MemoryVectorStore.fromDocuments(
41+
docs2,
42+
new OpenAIEmbeddings()
43+
);
44+
45+
const vectorstoreRetriever = vectorstore.asRetriever();
46+
47+
const retriever = new EnsembleRetriever({
48+
retrievers: [vectorstoreRetriever, keywordRetriever],
49+
weights: [0.5, 0.5],
50+
});
51+
52+
const query = "apples";
53+
const retrievedDocs = await retriever.invoke(query);
54+
55+
console.log(retrievedDocs);
56+
57+
/*
58+
[
59+
Document { pageContent: 'You like apples', metadata: { source: 2 } },
60+
Document { pageContent: 'I like apples', metadata: { source: 1 } },
61+
Document { pageContent: 'You like oranges', metadata: { source: 2 } },
62+
Document {
63+
pageContent: 'apples and oranges are fruits',
64+
metadata: { source: 1 }
65+
}
66+
]
67+
*/

langchain/.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,10 @@ retrievers/document_compressors.cjs
358358
retrievers/document_compressors.js
359359
retrievers/document_compressors.d.ts
360360
retrievers/document_compressors.d.cts
361+
retrievers/ensemble.cjs
362+
retrievers/ensemble.js
363+
retrievers/ensemble.d.ts
364+
retrievers/ensemble.d.cts
361365
retrievers/multi_query.cjs
362366
retrievers/multi_query.js
363367
retrievers/multi_query.d.ts

langchain/langchain.config.js

+1
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ export const config = {
143143
// retrievers
144144
"retrievers/contextual_compression": "retrievers/contextual_compression",
145145
"retrievers/document_compressors": "retrievers/document_compressors/index",
146+
"retrievers/ensemble": "retrievers/ensemble",
146147
"retrievers/multi_query": "retrievers/multi_query",
147148
"retrievers/multi_vector": "retrievers/multi_vector",
148149
"retrievers/parent_document": "retrievers/parent_document",

langchain/package.json

+13
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,10 @@
370370
"retrievers/document_compressors.js",
371371
"retrievers/document_compressors.d.ts",
372372
"retrievers/document_compressors.d.cts",
373+
"retrievers/ensemble.cjs",
374+
"retrievers/ensemble.js",
375+
"retrievers/ensemble.d.ts",
376+
"retrievers/ensemble.d.cts",
373377
"retrievers/multi_query.cjs",
374378
"retrievers/multi_query.js",
375379
"retrievers/multi_query.d.ts",
@@ -1725,6 +1729,15 @@
17251729
"import": "./retrievers/document_compressors.js",
17261730
"require": "./retrievers/document_compressors.cjs"
17271731
},
1732+
"./retrievers/ensemble": {
1733+
"types": {
1734+
"import": "./retrievers/ensemble.d.ts",
1735+
"require": "./retrievers/ensemble.d.cts",
1736+
"default": "./retrievers/ensemble.d.ts"
1737+
},
1738+
"import": "./retrievers/ensemble.js",
1739+
"require": "./retrievers/ensemble.cjs"
1740+
},
17281741
"./retrievers/multi_query": {
17291742
"types": {
17301743
"import": "./retrievers/multi_query.d.ts",

langchain/src/load/import_map.ts

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ export * as callbacks from "../callbacks/index.js";
3333
export * as output_parsers from "../output_parsers/index.js";
3434
export * as retrievers__contextual_compression from "../retrievers/contextual_compression.js";
3535
export * as retrievers__document_compressors from "../retrievers/document_compressors/index.js";
36+
export * as retrievers__ensemble from "../retrievers/ensemble.js";
3637
export * as retrievers__multi_query from "../retrievers/multi_query.js";
3738
export * as retrievers__multi_vector from "../retrievers/multi_vector.js";
3839
export * as retrievers__parent_document from "../retrievers/parent_document.js";

langchain/src/retrievers/ensemble.ts

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import { BaseRetriever, BaseRetrieverInput } from "@langchain/core/retrievers";
2+
import { Document, DocumentInterface } from "@langchain/core/documents";
3+
import { CallbackManagerForRetrieverRun } from "@langchain/core/callbacks/manager";
4+
5+
export interface EnsembleRetrieverInput extends BaseRetrieverInput {
6+
/** A list of retrievers to ensemble. */
7+
retrievers: BaseRetriever[];
8+
/**
9+
* A list of weights corresponding to the retrievers. Defaults to equal
10+
* weighting for all retrievers.
11+
*/
12+
weights?: number[];
13+
/**
14+
* A constant added to the rank, controlling the balance between the importance
15+
* of high-ranked items and the consideration given to lower-ranked items.
16+
* Default is 60.
17+
*/
18+
c?: number;
19+
}
20+
21+
/**
22+
* Ensemble retriever that aggregates and orders the results of
23+
* multiple retrievers by using weighted Reciprocal Rank Fusion.
24+
*/
25+
export class EnsembleRetriever extends BaseRetriever {
26+
static lc_name() {
27+
return "EnsembleRetriever";
28+
}
29+
30+
lc_namespace = ["langchain", "retrievers", "ensemble_retriever"];
31+
32+
retrievers: BaseRetriever[];
33+
34+
weights: number[];
35+
36+
c = 60;
37+
38+
constructor(args: EnsembleRetrieverInput) {
39+
super(args);
40+
this.retrievers = args.retrievers;
41+
this.weights =
42+
args.weights ||
43+
new Array(args.retrievers.length).fill(1 / args.retrievers.length);
44+
this.c = args.c || 60;
45+
}
46+
47+
async _getRelevantDocuments(
48+
query: string,
49+
runManager?: CallbackManagerForRetrieverRun
50+
) {
51+
return this._rankFusion(query, runManager);
52+
}
53+
54+
async _rankFusion(
55+
query: string,
56+
runManager?: CallbackManagerForRetrieverRun
57+
) {
58+
const retrieverDocs = await Promise.all(
59+
this.retrievers.map((retriever, i) =>
60+
retriever.invoke(query, {
61+
callbacks: runManager?.getChild(`retriever_${i + 1}`),
62+
})
63+
)
64+
);
65+
66+
const fusedDocs = await this._weightedReciprocalRank(retrieverDocs);
67+
return fusedDocs;
68+
}
69+
70+
async _weightedReciprocalRank(docList: DocumentInterface[][]) {
71+
if (docList.length !== this.weights.length) {
72+
throw new Error(
73+
"Number of retrieved document lists must be equal to the number of weights."
74+
);
75+
}
76+
77+
const rrfScoreDict = docList.reduce(
78+
(rffScore: Record<string, number>, retrieverDoc, idx) => {
79+
let rank = 1;
80+
const weight = this.weights[idx];
81+
while (rank <= retrieverDoc.length) {
82+
const { pageContent } = retrieverDoc[rank - 1];
83+
if (!rffScore[pageContent]) {
84+
// eslint-disable-next-line no-param-reassign
85+
rffScore[pageContent] = 0;
86+
}
87+
// eslint-disable-next-line no-param-reassign
88+
rffScore[pageContent] += weight / (rank + this.c);
89+
rank += 1;
90+
}
91+
92+
return rffScore;
93+
},
94+
{}
95+
);
96+
97+
const uniqueDocs = this._uniqueUnion(docList.flat());
98+
const sortedDocs = Array.from(uniqueDocs).sort(
99+
(a, b) => rrfScoreDict[b.pageContent] - rrfScoreDict[a.pageContent]
100+
);
101+
102+
return sortedDocs;
103+
}
104+
105+
private _uniqueUnion(documents: Document[]): Document[] {
106+
const documentSet = new Set();
107+
const result = [];
108+
109+
for (const doc of documents) {
110+
const key = doc.pageContent;
111+
if (!documentSet.has(key)) {
112+
documentSet.add(key);
113+
result.push(doc);
114+
}
115+
}
116+
117+
return result;
118+
}
119+
}

0 commit comments

Comments
 (0)