agsamantha/node_modules/langchain/dist/chains/router/multi_retrieval_qa.cjs

154 lines
6.7 KiB
JavaScript
Raw Normal View History

2024-10-02 15:15:21 -05:00
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.MultiRetrievalQAChain = void 0;
const zod_1 = require("zod");
const prompts_1 = require("@langchain/core/prompts");
const multi_route_js_1 = require("./multi_route.cjs");
const llm_router_js_1 = require("./llm_router.cjs");
const conversation_js_1 = require("../../chains/conversation.cjs");
const multi_retrieval_prompt_js_1 = require("./multi_retrieval_prompt.cjs");
const utils_js_1 = require("./utils.cjs");
const retrieval_qa_js_1 = require("../../chains/retrieval_qa.cjs");
const router_js_1 = require("../../output_parsers/router.cjs");
/**
 * A class that represents a multi-retrieval question answering chain in
 * the LangChain framework. It extends the MultiRouteChain class: an
 * LLMRouterChain picks one of several RetrievalQAChain destinations (one
 * per retriever) for each input, or falls back to a default chain.
 * @example
 * ```typescript
 * const multiRetrievalQAChain = MultiRetrievalQAChain.fromLLMAndRetrievers(
 *   new ChatOpenAI(),
 *   {
 *     retrieverNames: ["aqua teen", "mst3k", "animaniacs"],
 *     retrieverDescriptions: [
 *       "Good for answering questions about Aqua Teen Hunger Force theme song",
 *       "Good for answering questions about Mystery Science Theater 3000 theme song",
 *       "Good for answering questions about Animaniacs theme song",
 *     ],
 *     retrievers: [
 *       new MemoryVectorStore().asRetriever(3),
 *       new MemoryVectorStore().asRetriever(3),
 *       new MemoryVectorStore().asRetriever(3),
 *     ],
 *     retrievalQAChainOpts: {
 *       returnSourceDocuments: true,
 *     },
 *   },
 * );
 *
 * const result = await multiRetrievalQAChain.call({
 *   input:
 *     "In the Aqua Teen Hunger Force theme song, who calls himself the mike rula?",
 * });
 *
 * console.log(result.sourceDocuments, result.text);
 * ```
 */
class MultiRetrievalQAChain extends multi_route_js_1.MultiRouteChain {
    /**
     * The output keys this chain reports.
     * @returns Always `["result"]`.
     */
    get outputKeys() {
        return ["result"];
    }
    /**
     * Positional-argument wrapper that forwards to `fromLLMAndRetrievers`,
     * passing `options` through as `multiRetrievalChainOpts`.
     * @deprecated Use `fromLLMAndRetrievers` instead
     */
    static fromRetrievers(llm, retrieverNames, retrieverDescriptions, retrievers, retrieverPrompts, defaults, options) {
        return MultiRetrievalQAChain.fromLLMAndRetrievers(llm, {
            retrieverNames,
            retrieverDescriptions,
            retrievers,
            retrieverPrompts,
            defaults,
            multiRetrievalChainOpts: options,
        });
    }
    /**
     * Creates a MultiRetrievalQAChain from a language model and a set of
     * named retrievers. Each retriever becomes a RetrievalQAChain
     * destination keyed by its name, and an LLMRouterChain built from the
     * names and descriptions decides which destination handles an input.
     * @param llm A BaseLanguageModel instance. It is used for the router,
     *   every destination chain and any generated default chain.
     * @param fields Configuration object.
     * @param fields.retrieverNames Retriever names. They are the destination
     *   keys and are listed in the router prompt.
     * @param fields.retrieverDescriptions Descriptions paired with
     *   `retrieverNames` in the router prompt.
     * @param fields.retrievers BaseRetrieverInterface instances, paired with
     *   `retrieverNames`.
     * @param fields.retrieverPrompts Optional per-retriever prompts. A falsy
     *   entry keeps whatever `prompt` `retrievalQAChainOpts` carries (if any).
     * @param fields.defaults Optional MultiRetrievalDefaults. `defaultChain`
     *   is used as-is when given. Otherwise `defaultRetriever` (with optional
     *   `defaultPrompt`) builds a RetrievalQAChain. Otherwise a
     *   ConversationChain with output key "result" is created.
     * @param fields.multiRetrievalChainOpts Extra constructor options for the
     *   returned chain. `routerChain`, `destinationChains` and `defaultChain`
     *   are always overwritten by the values built here.
     * @param fields.retrievalQAChainOpts Extra options passed to every
     *   RetrievalQAChain.fromLLM call. This object is copied, never mutated.
     * @returns A new instance of MultiRetrievalQAChain.
     * @throws {Error} If `defaults.defaultPrompt` is provided without
     *   `defaults.defaultRetriever`.
     */
    static fromLLMAndRetrievers(llm, { retrieverNames, retrieverDescriptions, retrievers, retrieverPrompts, defaults, multiRetrievalChainOpts, retrievalQAChainOpts, }) {
        const { defaultRetriever, defaultPrompt, defaultChain } = defaults ?? {};
        if (defaultPrompt && !defaultRetriever) {
            throw new Error("`defaultRetriever` must be specified if `defaultPrompt` is provided. Received only `defaultPrompt`.");
        }
        // One "name: description" line per retriever, shown to the router LLM.
        const destinations = (0, utils_js_1.zipEntries)(retrieverNames, retrieverDescriptions).map(([name, desc]) => `${name}: ${desc}`);
        const structuredOutputParserSchema = zod_1.z.object({
            destination: zod_1.z
                .string()
                .optional()
                .describe('name of the question answering system to use or "DEFAULT"'),
            next_inputs: zod_1.z
                .object({
                query: zod_1.z
                    .string()
                    .describe("a potentially modified version of the original input"),
            })
                .describe("input to be fed to the next model"),
        });
        const outputParser = new router_js_1.RouterOutputParser(structuredOutputParserSchema);
        const destinationsStr = destinations.join("\n");
        const formatInstructions = outputParser.getFormatInstructions({ interpolationDepth: 4 });
        const routerTemplate = (0, prompts_1.interpolateFString)((0, multi_retrieval_prompt_js_1.STRUCTURED_MULTI_RETRIEVAL_ROUTER_TEMPLATE)(formatInstructions), {
            destinations: destinationsStr,
        });
        const routerPrompt = new prompts_1.PromptTemplate({
            template: routerTemplate,
            inputVariables: ["input"],
            outputParser,
        });
        const routerChain = llm_router_js_1.LLMRouterChain.fromLLM(llm, routerPrompt);
        const prompts = retrieverPrompts ?? retrievers.map(() => null);
        const destinationChains = (0, utils_js_1.zipEntries)(retrieverNames, retrievers, prompts).reduce((acc, [name, retriever, prompt]) => {
            // Fresh shallow copy per retriever. Writing `prompt` straight into
            // `retrievalQAChainOpts` would leak one retriever's prompt into
            // every later retriever without its own prompt, and would mutate
            // the caller's object.
            const opt = { ...retrievalQAChainOpts };
            if (prompt) {
                opt.prompt = prompt;
            }
            acc[name] = retrieval_qa_js_1.RetrievalQAChain.fromLLM(llm, retriever, opt);
            return acc;
        }, {});
        let _defaultChain;
        if (defaultChain) {
            _defaultChain = defaultChain;
        }
        else if (defaultRetriever) {
            _defaultChain = retrieval_qa_js_1.RetrievalQAChain.fromLLM(llm, defaultRetriever, {
                ...retrievalQAChainOpts,
                prompt: defaultPrompt,
            });
        }
        else {
            // Rename the conversation template's `{input}` placeholder to
            // `{query}` so the fallback takes the same `query` field that the
            // router's `next_inputs` schema above describes.
            const promptTemplate = conversation_js_1.DEFAULT_TEMPLATE.replace("{input}", "{query}");
            const prompt = new prompts_1.PromptTemplate({
                template: promptTemplate,
                inputVariables: ["history", "query"],
            });
            _defaultChain = new conversation_js_1.ConversationChain({
                llm,
                prompt,
                outputKey: "result",
            });
        }
        return new MultiRetrievalQAChain({
            ...multiRetrievalChainOpts,
            routerChain,
            destinationChains,
            defaultChain: _defaultChain,
        });
    }
    /**
     * Identifier string for this chain type.
     * @returns `"multi_retrieval_qa_chain"`.
     */
    _chainType() {
        return "multi_retrieval_qa_chain";
    }
}
exports.MultiRetrievalQAChain = MultiRetrievalQAChain;