88 lines
3.1 KiB
JavaScript
88 lines
3.1 KiB
JavaScript
|
import { BaseRetriever } from "@langchain/core/retrievers";
|
||
|
/**
 * Ensemble retriever that aggregates and orders the results of
 * multiple retrievers by using weighted Reciprocal Rank Fusion.
 *
 * Documents are identified by their `pageContent`: two documents with the
 * same content are treated as one, their scores are summed, and the first
 * occurrence (in retriever order) is the instance that is returned.
 */
export class EnsembleRetriever extends BaseRetriever {
  static lc_name() {
    return "EnsembleRetriever";
  }

  /**
   * @param {object} args Also forwarded unchanged to the `BaseRetriever` constructor.
   * @param {Array<object>} args.retrievers Retrievers to query; each is called
   *   as `retriever.invoke(query, { callbacks })`.
   * @param {number[]} [args.weights] One weight per retriever. Defaults to
   *   equal weights of `1 / retrievers.length`.
   * @param {number} [args.c=60] Constant added to each rank in the fusion
   *   formula `weight / (rank + c)`.
   */
  constructor(args) {
    super(args);
    Object.defineProperty(this, "lc_namespace", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: ["langchain", "retrievers", "ensemble_retriever"],
    });
    Object.defineProperty(this, "retrievers", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0,
    });
    Object.defineProperty(this, "weights", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0,
    });
    Object.defineProperty(this, "c", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: 60,
    });
    this.retrievers = args.retrievers;
    this.weights =
      args.weights ||
      new Array(args.retrievers.length).fill(1 / args.retrievers.length);
    // `??` rather than `||`: an explicit `c: 0` is a legitimate constant and
    // must not be silently replaced by the default.
    this.c = args.c ?? 60;
  }

  /**
   * Retrieves documents for `query` from every retriever and fuses them.
   *
   * @param {string} query Query passed to each retriever.
   * @param {object} [runManager] Optional callback manager; when present, a
   *   child is derived from it for each retriever.
   * @returns {Promise<Array<object>>} Fused, de-duplicated documents.
   */
  async _getRelevantDocuments(query, runManager) {
    return this._rankFusion(query, runManager);
  }

  /**
   * Invokes all retrievers concurrently and merges their result lists.
   *
   * @param {string} query Query passed to each retriever.
   * @param {object} [runManager] Optional callback manager.
   * @returns {Promise<Array<object>>} Documents ordered by fused score.
   */
  async _rankFusion(query, runManager) {
    const retrieverDocs = await Promise.all(
      this.retrievers.map((retriever, i) =>
        retriever.invoke(query, {
          // Child callbacks are tagged with the 1-based retriever position.
          callbacks: runManager?.getChild(`retriever_${i + 1}`),
        })
      )
    );
    return this._weightedReciprocalRank(retrieverDocs);
  }

  /**
   * Scores every document with weighted Reciprocal Rank Fusion and returns
   * the unique documents ordered by descending score.
   *
   * A document's score is the sum, over every list it appears in, of
   * `weight / (rank + c)`, where `rank` is its 1-based position in that list
   * and `weight` is the weight at the list's index.
   *
   * @param {Array<Array<object>>} docList One result list per retriever, in
   *   the same order as `this.weights`.
   * @returns {Promise<Array<object>>} Unique documents, highest score first.
   * @throws {Error} If the number of lists differs from the number of weights.
   */
  async _weightedReciprocalRank(docList) {
    if (docList.length !== this.weights.length) {
      throw new Error("Number of retrieved document lists must be equal to the number of weights.");
    }
    // A Map (not a plain object) so that page content which happens to equal
    // an Object.prototype member name ("constructor", "__proto__", ...) is
    // scored like any other document instead of colliding with the inherited
    // property and producing NaN comparisons.
    const rrfScores = new Map();
    docList.forEach((retrieverDocs, idx) => {
      const weight = this.weights[idx];
      for (const [position, { pageContent }] of retrieverDocs.entries()) {
        const rank = position + 1;
        const accumulated = rrfScores.get(pageContent) ?? 0;
        rrfScores.set(pageContent, accumulated + weight / (rank + this.c));
      }
    });
    // _uniqueUnion returns a fresh array, so sorting it in place is safe.
    const uniqueDocs = this._uniqueUnion(docList.flat());
    return uniqueDocs.sort(
      (a, b) => rrfScores.get(b.pageContent) - rrfScores.get(a.pageContent)
    );
  }

  /**
   * De-duplicates documents by `pageContent`, keeping the first occurrence
   * and preserving the input order.
   *
   * @param {Array<object>} documents Documents to de-duplicate.
   * @returns {Array<object>} New array holding one document per distinct content.
   */
  _uniqueUnion(documents) {
    const seenContents = new Set();
    const result = [];
    for (const doc of documents) {
      const key = doc.pageContent;
      if (!seenContents.has(key)) {
        seenContents.add(key);
        result.push(doc);
      }
    }
    return result;
  }
}
|