150 lines
6.4 KiB
JavaScript
150 lines
6.4 KiB
JavaScript
|
import { z } from "zod";
|
||
|
import { interpolateFString, PromptTemplate } from "@langchain/core/prompts";
|
||
|
import { MultiRouteChain } from "./multi_route.js";
|
||
|
import { LLMRouterChain } from "./llm_router.js";
|
||
|
import { ConversationChain, DEFAULT_TEMPLATE, } from "../../chains/conversation.js";
|
||
|
import { STRUCTURED_MULTI_RETRIEVAL_ROUTER_TEMPLATE } from "./multi_retrieval_prompt.js";
|
||
|
import { zipEntries } from "./utils.js";
|
||
|
import { RetrievalQAChain, } from "../../chains/retrieval_qa.js";
|
||
|
import { RouterOutputParser } from "../../output_parsers/router.js";
|
||
|
/**
 * A class that represents a multi-retrieval question answering chain in
 * the LangChain framework. It extends the MultiRouteChain class and
 * provides additional functionality specific to multi-retrieval QA
 * chains: an LLM router chooses one of several named retrieval QA chains
 * (or a default chain) to answer each input.
 * @example
 * ```typescript
 * const multiRetrievalQAChain = MultiRetrievalQAChain.fromLLMAndRetrievers(
 *   new ChatOpenAI(),
 *   {
 *     retrieverNames: ["aqua teen", "mst3k", "animaniacs"],
 *     retrieverDescriptions: [
 *       "Good for answering questions about Aqua Teen Hunger Force theme song",
 *       "Good for answering questions about Mystery Science Theater 3000 theme song",
 *       "Good for answering questions about Animaniacs theme song",
 *     ],
 *     retrievers: [
 *       new MemoryVectorStore().asRetriever(3),
 *       new MemoryVectorStore().asRetriever(3),
 *       new MemoryVectorStore().asRetriever(3),
 *     ],
 *     retrievalQAChainOpts: {
 *       returnSourceDocuments: true,
 *     },
 *   },
 * );
 *
 * const result = await multiRetrievalQAChain.call({
 *   input:
 *     "In the Aqua Teen Hunger Force theme song, who calls himself the mike rula?",
 * });
 *
 * console.log(result.sourceDocuments, result.text);
 * ```
 */
export class MultiRetrievalQAChain extends MultiRouteChain {
  /** The output keys declared by this chain. */
  get outputKeys() {
    return ["result"];
  }

  /**
   * Positional-argument variant of `fromLLMAndRetrievers`. Every argument is
   * forwarded unchanged; `options` is passed as `multiRetrievalChainOpts`.
   * @deprecated Use `fromLLMAndRetrievers` instead
   */
  static fromRetrievers(
    llm,
    retrieverNames,
    retrieverDescriptions,
    retrievers,
    retrieverPrompts,
    defaults,
    options,
  ) {
    return MultiRetrievalQAChain.fromLLMAndRetrievers(llm, {
      retrieverNames,
      retrieverDescriptions,
      retrievers,
      retrieverPrompts,
      defaults,
      multiRetrievalChainOpts: options,
    });
  }

  /**
   * Creates a MultiRetrievalQAChain from a language model and a set of
   * named retrievers. One RetrievalQAChain is built per retriever, and an
   * LLM router chain picks among them using the supplied descriptions.
   *
   * @param llm A BaseLanguageModel instance, shared by the router and all
   *   sub-chains.
   * @param options Configuration object.
   * @param options.retrieverNames Destination names, aligned by index with
   *   `retrieverDescriptions` and `retrievers`.
   * @param options.retrieverDescriptions Descriptions shown to the router
   *   as `name: description` lines.
   * @param options.retrievers BaseRetrieverInterface instances.
   * @param options.retrieverPrompts Optional per-retriever PromptTemplates.
   *   A falsy entry keeps whatever `retrievalQAChainOpts` specifies.
   * @param options.defaults Optional `{ defaultRetriever, defaultPrompt,
   *   defaultChain }`.
   *   - `defaultChain` is used as-is when provided.
   *   - Otherwise `defaultRetriever` gets its own RetrievalQAChain.
   *   - Otherwise a ConversationChain is used.
   * @param options.multiRetrievalChainOpts Extra fields spread into the
   *   MultiRetrievalQAChain constructor.
   * @param options.retrievalQAChainOpts Options copied into every
   *   RetrievalQAChain built here. The caller's object is never mutated.
   * @returns A new instance of MultiRetrievalQAChain.
   * @throws {Error} If `defaults.defaultPrompt` is given without
   *   `defaults.defaultRetriever`.
   */
  static fromLLMAndRetrievers(
    llm,
    {
      retrieverNames,
      retrieverDescriptions,
      retrievers,
      retrieverPrompts,
      defaults,
      multiRetrievalChainOpts,
      retrievalQAChainOpts,
    },
  ) {
    const { defaultRetriever, defaultPrompt, defaultChain } = defaults ?? {};
    if (defaultPrompt && !defaultRetriever) {
      throw new Error(
        "`defaultRetriever` must be specified if `defaultPrompt` is provided. Received only `defaultPrompt`.",
      );
    }

    // --- Router chain -----------------------------------------------------
    const destinationsStr = zipEntries(retrieverNames, retrieverDescriptions)
      .map(([name, desc]) => `${name}: ${desc}`)
      .join("\n");
    const structuredOutputParserSchema = z.object({
      destination: z
        .string()
        .optional()
        .describe('name of the question answering system to use or "DEFAULT"'),
      next_inputs: z
        .object({
          query: z
            .string()
            .describe("a potentially modified version of the original input"),
        })
        .describe("input to be fed to the next model"),
    });
    const outputParser = new RouterOutputParser(structuredOutputParserSchema);
    const routerTemplate = interpolateFString(
      STRUCTURED_MULTI_RETRIEVAL_ROUTER_TEMPLATE(
        outputParser.getFormatInstructions({ interpolationDepth: 4 }),
      ),
      { destinations: destinationsStr },
    );
    const routerPrompt = new PromptTemplate({
      template: routerTemplate,
      inputVariables: ["input"],
      outputParser,
    });
    const routerChain = LLMRouterChain.fromLLM(llm, routerPrompt);

    // --- Destination chains ----------------------------------------------
    const prompts = retrieverPrompts ?? retrievers.map(() => null);
    const destinationChains = Object.fromEntries(
      zipEntries(retrieverNames, retrievers, prompts).map(
        ([name, retriever, prompt]) => {
          // Fresh copy per destination: a prompt set for one retriever must
          // not leak into later ones, and the caller's options object must
          // not be mutated.
          const opt = { ...retrievalQAChainOpts };
          if (prompt) {
            opt.prompt = prompt;
          }
          return [name, RetrievalQAChain.fromLLM(llm, retriever, opt)];
        },
      ),
    );

    // --- Default chain ----------------------------------------------------
    let resolvedDefaultChain;
    if (defaultChain) {
      resolvedDefaultChain = defaultChain;
    } else if (defaultRetriever) {
      resolvedDefaultChain = RetrievalQAChain.fromLLM(llm, defaultRetriever, {
        ...retrievalQAChainOpts,
        prompt: defaultPrompt,
      });
    } else {
      // The router's `next_inputs` carry a `query` field (schema above), so
      // the conversation template's `{input}` placeholder becomes `{query}`.
      const promptTemplate = DEFAULT_TEMPLATE.replace("{input}", "{query}");
      const prompt = new PromptTemplate({
        template: promptTemplate,
        inputVariables: ["history", "query"],
      });
      resolvedDefaultChain = new ConversationChain({
        llm,
        prompt,
        outputKey: "result",
      });
    }

    return new MultiRetrievalQAChain({
      ...multiRetrievalChainOpts,
      routerChain,
      destinationChains,
      defaultChain: resolvedDefaultChain,
    });
  }

  /** Identifier string for this chain type. */
  _chainType() {
    return "multi_retrieval_qa_chain";
  }
}
|