agsamantha/node_modules/langchain/dist/retrievers/ensemble.js
2024-10-02 15:15:21 -05:00

87 lines
3.1 KiB
JavaScript

import { BaseRetriever } from "@langchain/core/retrievers";
/**
 * Ensemble retriever that aggregates and orders the results of
 * multiple retrievers by using weighted Reciprocal Rank Fusion.
 *
 * Every retriever is queried with the same input. A document's fused score is
 * the sum, over each of its occurrences in the result lists, of
 * `weight / (rank + c)`, where `rank` is the 1-based position in that list and
 * `weight` belongs to the retriever that produced the list. Documents are
 * identified by their `pageContent`: documents with identical content share
 * one score and only the first occurrence is returned.
 */
export class EnsembleRetriever extends BaseRetriever {
    static lc_name() {
        return "EnsembleRetriever";
    }
    /**
     * @param {object} args - Configuration; also forwarded to the base constructor.
     * @param {Array} args.retrievers - Retrievers to combine. Each one is called
     *   through `invoke(query, options)`.
     * @param {number[]} [args.weights] - One weight per retriever. Defaults to
     *   equal weights of `1 / retrievers.length`.
     * @param {number} [args.c] - Constant added to the rank in the score
     *   denominator. Defaults to 60 when omitted; an explicit `0` is honoured.
     */
    constructor(args) {
        super(args);
        Object.defineProperty(this, "lc_namespace", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: ["langchain", "retrievers", "ensemble_retriever"]
        });
        Object.defineProperty(this, "retrievers", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "weights", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "c", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 60
        });
        this.retrievers = args.retrievers;
        this.weights =
            args.weights ||
                new Array(args.retrievers.length).fill(1 / args.retrievers.length);
        // `??` rather than `||`: 0 is a legitimate constant and must not be
        // replaced by the default.
        this.c = args.c ?? 60;
    }
    /**
     * Retrieval entry point; delegates to the rank-fusion pipeline.
     *
     * @param {string} query - Query passed to every underlying retriever.
     * @param {object} [runManager] - Optional run manager; when present, its
     *   `getChild` is used to derive per-retriever callbacks.
     * @returns {Promise<Array>} Fused, de-duplicated documents, best first.
     */
    async _getRelevantDocuments(query, runManager) {
        return this._rankFusion(query, runManager);
    }
    /**
     * Queries all retrievers concurrently and fuses their result lists.
     *
     * @param {string} query - Query passed to every underlying retriever.
     * @param {object} [runManager] - Optional run manager (see above).
     * @returns {Promise<Array>} Fused, de-duplicated documents, best first.
     */
    async _rankFusion(query, runManager) {
        const retrieverDocs = await Promise.all(this.retrievers.map((retriever, i) => retriever.invoke(query, {
            callbacks: runManager?.getChild(`retriever_${i + 1}`),
        })));
        return this._weightedReciprocalRank(retrieverDocs);
    }
    /**
     * Scores documents with weighted Reciprocal Rank Fusion and returns the
     * unique documents ordered by descending score. Documents with equal
     * scores keep their first-seen order.
     *
     * @param {Array<Array>} docList - One result list per retriever, in the
     *   same order as `this.weights`.
     * @returns {Promise<Array>} Unique documents sorted by fused score.
     * @throws {Error} When the number of lists differs from the number of weights.
     */
    async _weightedReciprocalRank(docList) {
        if (docList.length !== this.weights.length) {
            throw new Error("Number of retrieved document lists must be equal to the number of weights.");
        }
        // A Map (not a plain object) so that content such as "toString",
        // "constructor" or "__proto__" cannot collide with inherited
        // Object.prototype members and corrupt the score.
        const rrfScores = new Map();
        docList.forEach((retrieverDocs, idx) => {
            const weight = this.weights[idx];
            for (const [position, { pageContent }] of retrieverDocs.entries()) {
                const rank = position + 1;
                const previous = rrfScores.get(pageContent) ?? 0;
                rrfScores.set(pageContent, previous + weight / (rank + this.c));
            }
        });
        // `_uniqueUnion` returns a fresh array, so sorting it in place does
        // not mutate any caller-owned list.
        const uniqueDocs = this._uniqueUnion(docList.flat());
        return uniqueDocs.sort((a, b) => rrfScores.get(b.pageContent) - rrfScores.get(a.pageContent));
    }
    /**
     * Removes documents whose `pageContent` has already been seen, keeping
     * the first occurrence and the original relative order.
     *
     * @param {Array} documents - Documents, possibly with repeated content.
     * @returns {Array} New array holding the first document for each distinct content.
     */
    _uniqueUnion(documents) {
        const seenContent = new Set();
        const result = [];
        for (const doc of documents) {
            const key = doc.pageContent;
            if (!seenContent.has(key)) {
                seenContent.add(key);
                result.push(doc);
            }
        }
        return result;
    }
}