import { PromptTemplate } from "@langchain/core/prompts"; import { BaseChain } from "./base.js"; import { LLMChain } from "./llm_chain.js"; /** * Chain that combines documents by stuffing into context. * @augments BaseChain * @augments StuffDocumentsChainInput */ export class StuffDocumentsChain extends BaseChain { static lc_name() { return "StuffDocumentsChain"; } get inputKeys() { return [this.inputKey, ...this.llmChain.inputKeys].filter((key) => key !== this.documentVariableName); } get outputKeys() { return this.llmChain.outputKeys; } constructor(fields) { super(fields); Object.defineProperty(this, "llmChain", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "inputKey", { enumerable: true, configurable: true, writable: true, value: "input_documents" }); Object.defineProperty(this, "documentVariableName", { enumerable: true, configurable: true, writable: true, value: "context" }); this.llmChain = fields.llmChain; this.documentVariableName = fields.documentVariableName ?? this.documentVariableName; this.inputKey = fields.inputKey ?? this.inputKey; } /** @ignore */ _prepInputs(values) { if (!(this.inputKey in values)) { throw new Error(`Document key ${this.inputKey} not found.`); } const { [this.inputKey]: docs, ...rest } = values; const texts = docs.map(({ pageContent }) => pageContent); const text = texts.join("\n\n"); return { ...rest, [this.documentVariableName]: text, }; } /** @ignore */ async _call(values, runManager) { const result = await this.llmChain.call(this._prepInputs(values), runManager?.getChild("combine_documents")); return result; } _chainType() { return "stuff_documents_chain"; } static async deserialize(data) { if (!data.llm_chain) { throw new Error("Missing llm_chain"); } return new StuffDocumentsChain({ llmChain: await LLMChain.deserialize(data.llm_chain), }); } serialize() { return { _type: this._chainType(), llm_chain: this.llmChain.serialize(), }; } } /** * Combine documents by mapping a chain over them, then combining results. * @augments BaseChain * @augments StuffDocumentsChainInput */ export class MapReduceDocumentsChain extends BaseChain { static lc_name() { return "MapReduceDocumentsChain"; } get inputKeys() { return [this.inputKey, ...this.combineDocumentChain.inputKeys]; } get outputKeys() { return this.combineDocumentChain.outputKeys; } constructor(fields) { super(fields); Object.defineProperty(this, "llmChain", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "inputKey", { enumerable: true, configurable: true, writable: true, value: "input_documents" }); Object.defineProperty(this, "documentVariableName", { enumerable: true, configurable: true, writable: true, value: "context" }); Object.defineProperty(this, "returnIntermediateSteps", { enumerable: true, configurable: true, writable: true, value: false }); Object.defineProperty(this, "maxTokens", { enumerable: true, configurable: true, writable: true, value: 3000 }); Object.defineProperty(this, "maxIterations", { enumerable: true, configurable: true, writable: true, value: 10 }); Object.defineProperty(this, "ensureMapStep", { enumerable: true, configurable: true, writable: true, value: false }); Object.defineProperty(this, "combineDocumentChain", { enumerable: true, configurable: true, writable: true, value: void 0 }); this.llmChain = fields.llmChain; this.combineDocumentChain = fields.combineDocumentChain; this.documentVariableName = fields.documentVariableName ?? this.documentVariableName; this.ensureMapStep = fields.ensureMapStep ?? this.ensureMapStep; this.inputKey = fields.inputKey ?? this.inputKey; this.maxTokens = fields.maxTokens ?? this.maxTokens; this.maxIterations = fields.maxIterations ?? this.maxIterations; this.returnIntermediateSteps = fields.returnIntermediateSteps ?? false; } /** @ignore */ async _call(values, runManager) { if (!(this.inputKey in values)) { throw new Error(`Document key ${this.inputKey} not found.`); } const { [this.inputKey]: docs, ...rest } = values; let currentDocs = docs; let intermediateSteps = []; // For each iteration, we'll use the `llmChain` to get a new result for (let i = 0; i < this.maxIterations; i += 1) { const inputs = currentDocs.map((d) => ({ [this.documentVariableName]: d.pageContent, ...rest, })); const canSkipMapStep = i !== 0 || !this.ensureMapStep; if (canSkipMapStep) { // Calculate the total tokens required in the input const formatted = await this.combineDocumentChain.llmChain.prompt.format(this.combineDocumentChain._prepInputs({ [this.combineDocumentChain.inputKey]: currentDocs, ...rest, })); const length = await this.combineDocumentChain.llmChain._getNumTokens(formatted); const withinTokenLimit = length < this.maxTokens; // If we can skip the map step, and we're within the token limit, we don't // need to run the map step, so just break out of the loop. if (withinTokenLimit) { break; } } const results = await this.llmChain.apply(inputs, // If we have a runManager, then we need to create a child for each input // so that we can track the progress of each input. runManager ? Array.from({ length: inputs.length }, (_, i) => runManager.getChild(`map_${i + 1}`)) : undefined); const { outputKey } = this.llmChain; // If the flag is set, then concat that to the intermediate steps if (this.returnIntermediateSteps) { intermediateSteps = intermediateSteps.concat(results.map((r) => r[outputKey])); } currentDocs = results.map((r) => ({ pageContent: r[outputKey], metadata: {}, })); } // Now, with the final result of all the inputs from the `llmChain`, we can // run the `combineDocumentChain` over them. const newInputs = { [this.combineDocumentChain.inputKey]: currentDocs, ...rest, }; const result = await this.combineDocumentChain.call(newInputs, runManager?.getChild("combine_documents")); // Return the intermediate steps results if the flag is set if (this.returnIntermediateSteps) { return { ...result, intermediateSteps }; } return result; } _chainType() { return "map_reduce_documents_chain"; } static async deserialize(data) { if (!data.llm_chain) { throw new Error("Missing llm_chain"); } if (!data.combine_document_chain) { throw new Error("Missing combine_document_chain"); } return new MapReduceDocumentsChain({ llmChain: await LLMChain.deserialize(data.llm_chain), combineDocumentChain: await StuffDocumentsChain.deserialize(data.combine_document_chain), }); } serialize() { return { _type: this._chainType(), llm_chain: this.llmChain.serialize(), combine_document_chain: this.combineDocumentChain.serialize(), }; } } /** * Combine documents by doing a first pass and then refining on more documents. * @augments BaseChain * @augments RefineDocumentsChainInput */ export class RefineDocumentsChain extends BaseChain { static lc_name() { return "RefineDocumentsChain"; } get defaultDocumentPrompt() { return new PromptTemplate({ inputVariables: ["page_content"], template: "{page_content}", }); } get inputKeys() { return [ ...new Set([ this.inputKey, ...this.llmChain.inputKeys, ...this.refineLLMChain.inputKeys, ]), ].filter((key) => key !== this.documentVariableName && key !== this.initialResponseName); } get outputKeys() { return [this.outputKey]; } constructor(fields) { super(fields); Object.defineProperty(this, "llmChain", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "inputKey", { enumerable: true, configurable: true, writable: true, value: "input_documents" }); Object.defineProperty(this, "outputKey", { enumerable: true, configurable: true, writable: true, value: "output_text" }); Object.defineProperty(this, "documentVariableName", { enumerable: true, configurable: true, writable: true, value: "context" }); Object.defineProperty(this, "initialResponseName", { enumerable: true, configurable: true, writable: true, value: "existing_answer" }); Object.defineProperty(this, "refineLLMChain", { enumerable: true, configurable: true, writable: true, value: void 0 }); Object.defineProperty(this, "documentPrompt", { enumerable: true, configurable: true, writable: true, value: this.defaultDocumentPrompt }); this.llmChain = fields.llmChain; this.refineLLMChain = fields.refineLLMChain; this.documentVariableName = fields.documentVariableName ?? this.documentVariableName; this.inputKey = fields.inputKey ?? this.inputKey; this.outputKey = fields.outputKey ?? this.outputKey; this.documentPrompt = fields.documentPrompt ?? this.documentPrompt; this.initialResponseName = fields.initialResponseName ?? this.initialResponseName; } /** @ignore */ async _constructInitialInputs(doc, rest) { const baseInfo = { page_content: doc.pageContent, ...doc.metadata, }; const documentInfo = {}; this.documentPrompt.inputVariables.forEach((value) => { documentInfo[value] = baseInfo[value]; }); const baseInputs = { [this.documentVariableName]: await this.documentPrompt.format({ ...documentInfo, }), }; const inputs = { ...baseInputs, ...rest }; return inputs; } /** @ignore */ async _constructRefineInputs(doc, res) { const baseInfo = { page_content: doc.pageContent, ...doc.metadata, }; const documentInfo = {}; this.documentPrompt.inputVariables.forEach((value) => { documentInfo[value] = baseInfo[value]; }); const baseInputs = { [this.documentVariableName]: await this.documentPrompt.format({ ...documentInfo, }), }; const inputs = { [this.initialResponseName]: res, ...baseInputs }; return inputs; } /** @ignore */ async _call(values, runManager) { if (!(this.inputKey in values)) { throw new Error(`Document key ${this.inputKey} not found.`); } const { [this.inputKey]: docs, ...rest } = values; const currentDocs = docs; const initialInputs = await this._constructInitialInputs(currentDocs[0], rest); let res = await this.llmChain.predict({ ...initialInputs }, runManager?.getChild("answer")); const refineSteps = [res]; for (let i = 1; i < currentDocs.length; i += 1) { const refineInputs = await this._constructRefineInputs(currentDocs[i], res); const inputs = { ...refineInputs, ...rest }; res = await this.refineLLMChain.predict({ ...inputs }, runManager?.getChild("refine")); refineSteps.push(res); } return { [this.outputKey]: res }; } _chainType() { return "refine_documents_chain"; } static async deserialize(data) { const SerializedLLMChain = data.llm_chain; if (!SerializedLLMChain) { throw new Error("Missing llm_chain"); } const SerializedRefineDocumentChain = data.refine_llm_chain; if (!SerializedRefineDocumentChain) { throw new Error("Missing refine_llm_chain"); } return new RefineDocumentsChain({ llmChain: await LLMChain.deserialize(SerializedLLMChain), refineLLMChain: await LLMChain.deserialize(SerializedRefineDocumentChain), }); } serialize() { return { _type: this._chainType(), llm_chain: this.llmChain.serialize(), refine_llm_chain: this.refineLLMChain.serialize(), }; } }