395 lines
14 KiB
JavaScript
395 lines
14 KiB
JavaScript
import { PromptTemplate } from "@langchain/core/prompts";
|
|
import { BaseChain } from "./base.js";
|
|
import { LLMChain } from "./llm_chain.js";
|
|
/**
 * Combines a list of documents by concatenating ("stuffing") all of their
 * page contents into one prompt variable and running the wrapped chain once.
 * @augments BaseChain
 */
export class StuffDocumentsChain extends BaseChain {
  static lc_name() {
    return "StuffDocumentsChain";
  }

  /**
   * Keys callers must supply: the documents key plus the wrapped chain's
   * inputs, minus the prompt variable this chain fills in itself.
   */
  get inputKeys() {
    const candidates = [this.inputKey, ...this.llmChain.inputKeys];
    return candidates.filter((name) => name !== this.documentVariableName);
  }

  /** Output keys are exactly those of the wrapped chain. */
  get outputKeys() {
    return this.llmChain.outputKeys;
  }

  /**
   * @param {object} fields - Chain options.
   * @param {object} fields.llmChain - Chain invoked with the stuffed inputs.
   * @param {string} [fields.inputKey="input_documents"] - Key holding the documents.
   * @param {string} [fields.documentVariableName="context"] - Prompt variable that
   *   receives the joined document text.
   */
  constructor(fields) {
    super(fields);
    // Define own, writable, enumerable, configurable data properties with their
    // defaults first, then overlay whatever the caller supplied.
    const fieldDefaults = {
      llmChain: undefined,
      inputKey: "input_documents",
      documentVariableName: "context",
    };
    for (const [name, value] of Object.entries(fieldDefaults)) {
      Object.defineProperty(this, name, {
        enumerable: true,
        configurable: true,
        writable: true,
        value,
      });
    }
    this.llmChain = fields.llmChain;
    this.documentVariableName =
      fields.documentVariableName ?? this.documentVariableName;
    this.inputKey = fields.inputKey ?? this.inputKey;
  }

  /**
   * Replaces the documents entry of `values` with the documents' page contents
   * joined by blank lines, stored under `documentVariableName`.
   * @param {object} values - Chain inputs; must contain `inputKey`.
   * @returns {object} The remaining inputs plus the stuffed text.
   * @throws {Error} If `values` has no `inputKey` entry.
   * @ignore
   */
  _prepInputs(values) {
    if (!(this.inputKey in values)) {
      throw new Error(`Document key ${this.inputKey} not found.`);
    }
    const { [this.inputKey]: documents, ...passthrough } = values;
    const stuffed = documents.map((doc) => doc.pageContent).join("\n\n");
    return { ...passthrough, [this.documentVariableName]: stuffed };
  }

  /**
   * Runs the wrapped chain on the stuffed inputs.
   * @param {object} values - Chain inputs; must contain `inputKey`.
   * @param {object} [runManager] - Optional; its `combine_documents` child is
   *   forwarded to the wrapped chain.
   * @ignore
   */
  async _call(values, runManager) {
    const prepared = this._prepInputs(values);
    const childCallbacks = runManager?.getChild("combine_documents");
    return this.llmChain.call(prepared, childCallbacks);
  }

  _chainType() {
    return "stuff_documents_chain";
  }

  /**
   * Rebuilds a chain from `serialize()` output. Only `llm_chain` is read, so
   * `inputKey` and `documentVariableName` come back as their defaults.
   * @throws {Error} If `data.llm_chain` is missing.
   */
  static async deserialize(data) {
    const serializedLLMChain = data.llm_chain;
    if (!serializedLLMChain) {
      throw new Error("Missing llm_chain");
    }
    const llmChain = await LLMChain.deserialize(serializedLLMChain);
    return new StuffDocumentsChain({ llmChain });
  }

  /** Serializes the chain type and the wrapped chain. */
  serialize() {
    const llmChain = this.llmChain.serialize();
    return { _type: this._chainType(), llm_chain: llmChain };
  }
}
|
|
/**
 * Combines documents in two phases: a "map" phase that runs `llmChain` over
 * each document individually, and a "reduce" phase that hands the results to
 * `combineDocumentChain`.
 *
 * Before each map pass the combine prompt is formatted for the current
 * documents and its tokens are counted; once that count is below `maxTokens`
 * the map phase stops. At most `maxIterations` map passes run, after which the
 * current documents are combined regardless of size.
 * @augments BaseChain
 */
export class MapReduceDocumentsChain extends BaseChain {
  static lc_name() {
    return "MapReduceDocumentsChain";
  }

  /** Keys callers must supply: the documents key plus the combine chain's inputs. */
  get inputKeys() {
    return [this.inputKey, ...this.combineDocumentChain.inputKeys];
  }

  /** Output keys are those of `combineDocumentChain`. */
  get outputKeys() {
    return this.combineDocumentChain.outputKeys;
  }

  /**
   * @param {object} fields - Chain options.
   * @param {object} fields.llmChain - Chain applied to each document in the map phase.
   * @param {object} fields.combineDocumentChain - Chain that combines the mapped documents;
   *   its `llmChain.prompt`, `_prepInputs` and `inputKey` are used for token counting.
   * @param {string} [fields.inputKey="input_documents"] - Key holding the input documents.
   * @param {string} [fields.documentVariableName="context"] - Map-prompt variable that
   *   receives each document's page content.
   * @param {boolean} [fields.ensureMapStep=false] - Run at least one map pass even when
   *   the documents already fit.
   * @param {number} [fields.maxTokens=3000] - The combine prompt must be strictly below
   *   this many tokens for mapping to stop.
   * @param {number} [fields.maxIterations=10] - Maximum number of map passes.
   * @param {boolean} [fields.returnIntermediateSteps=false] - Also return every map output
   *   as `intermediateSteps`.
   */
  constructor(fields) {
    super(fields);
    // Define own data properties with their defaults first, then overlay
    // whatever the caller supplied.
    const fieldDefaults = {
      llmChain: undefined,
      inputKey: "input_documents",
      documentVariableName: "context",
      returnIntermediateSteps: false,
      maxTokens: 3000,
      maxIterations: 10,
      ensureMapStep: false,
      combineDocumentChain: undefined,
    };
    for (const [name, value] of Object.entries(fieldDefaults)) {
      Object.defineProperty(this, name, {
        enumerable: true,
        configurable: true,
        writable: true,
        value,
      });
    }
    this.llmChain = fields.llmChain;
    this.combineDocumentChain = fields.combineDocumentChain;
    this.documentVariableName =
      fields.documentVariableName ?? this.documentVariableName;
    this.ensureMapStep = fields.ensureMapStep ?? this.ensureMapStep;
    this.inputKey = fields.inputKey ?? this.inputKey;
    this.maxTokens = fields.maxTokens ?? this.maxTokens;
    this.maxIterations = fields.maxIterations ?? this.maxIterations;
    this.returnIntermediateSteps =
      fields.returnIntermediateSteps ?? this.returnIntermediateSteps;
  }

  /**
   * Runs the map phase until the documents fit the combine prompt (or the
   * iteration cap is reached), then runs the combine chain.
   * @param {object} values - Must contain `inputKey`; every other entry is
   *   forwarded to both the map and the combine chains.
   * @param {object} [runManager] - Optional; when present its `getChild`
   *   supplies callbacks named `map_1`, `map_2`, ... for each map input and
   *   `combine_documents` for the combine call.
   * @returns {Promise<object>} The combine chain's output, plus
   *   `intermediateSteps` (all map outputs, in order) when
   *   `returnIntermediateSteps` is set.
   * @throws {Error} If `values` has no `inputKey` entry.
   * @ignore
   */
  async _call(values, runManager) {
    if (!(this.inputKey in values)) {
      throw new Error(`Document key ${this.inputKey} not found.`);
    }
    const { [this.inputKey]: docs, ...rest } = values;
    let currentDocs = docs;
    let intermediateSteps = [];
    for (let iteration = 0; iteration < this.maxIterations; iteration += 1) {
      // Only the first pass can be forced through the map step.
      const canSkipMapStep = iteration !== 0 || !this.ensureMapStep;
      if (canSkipMapStep) {
        // Measure the prompt the combine chain would actually send.
        const formatted = await this.combineDocumentChain.llmChain.prompt.format(
          this.combineDocumentChain._prepInputs({
            [this.combineDocumentChain.inputKey]: currentDocs,
            ...rest,
          }),
        );
        const length =
          await this.combineDocumentChain.llmChain._getNumTokens(formatted);
        if (length < this.maxTokens) {
          break;
        }
      }
      // Build the map inputs only once we know the map step will run.
      const mapInputs = currentDocs.map((doc) => ({
        [this.documentVariableName]: doc.pageContent,
        ...rest,
      }));
      // One child callback per map input so each call is tracked separately.
      const mapCallbacks = runManager
        ? mapInputs.map((_, index) => runManager.getChild(`map_${index + 1}`))
        : undefined;
      const results = await this.llmChain.apply(mapInputs, mapCallbacks);
      const { outputKey } = this.llmChain;
      const outputs = results.map((result) => result[outputKey]);
      if (this.returnIntermediateSteps) {
        intermediateSteps = intermediateSteps.concat(outputs);
      }
      // Map outputs become the documents for the next size check.
      currentDocs = outputs.map((pageContent) => ({ pageContent, metadata: {} }));
    }
    const result = await this.combineDocumentChain.call(
      { [this.combineDocumentChain.inputKey]: currentDocs, ...rest },
      runManager?.getChild("combine_documents"),
    );
    return this.returnIntermediateSteps ? { ...result, intermediateSteps } : result;
  }

  _chainType() {
    return "map_reduce_documents_chain";
  }

  /**
   * Rebuilds a chain from `serialize()` output. Only the two sub-chains are
   * restored; every other option comes back as its default.
   * @throws {Error} If `data.llm_chain` or `data.combine_document_chain` is missing.
   */
  static async deserialize(data) {
    if (!data.llm_chain) {
      throw new Error("Missing llm_chain");
    }
    if (!data.combine_document_chain) {
      throw new Error("Missing combine_document_chain");
    }
    return new MapReduceDocumentsChain({
      llmChain: await LLMChain.deserialize(data.llm_chain),
      combineDocumentChain: await StuffDocumentsChain.deserialize(data.combine_document_chain),
    });
  }

  /** Serializes the chain type and both sub-chains. */
  serialize() {
    return {
      _type: this._chainType(),
      llm_chain: this.llmChain.serialize(),
      combine_document_chain: this.combineDocumentChain.serialize(),
    };
  }
}
|
|
/**
 * Combines documents iteratively: `llmChain` produces an initial answer from
 * the first document, then `refineLLMChain` is called once per remaining
 * document with the previous answer and that document's formatted text.
 * @augments BaseChain
 */
export class RefineDocumentsChain extends BaseChain {
  static lc_name() {
    return "RefineDocumentsChain";
  }

  /** Prompt used when no `documentPrompt` is given: renders just the page content. */
  get defaultDocumentPrompt() {
    return new PromptTemplate({
      inputVariables: ["page_content"],
      template: "{page_content}",
    });
  }

  /**
   * Keys callers must supply: the documents key plus the union of both
   * chains' inputs, minus the two prompt variables this chain fills in itself.
   */
  get inputKeys() {
    const uniqueKeys = new Set([
      this.inputKey,
      ...this.llmChain.inputKeys,
      ...this.refineLLMChain.inputKeys,
    ]);
    return [...uniqueKeys].filter(
      (key) => key !== this.documentVariableName && key !== this.initialResponseName,
    );
  }

  /** The single key under which the final answer is returned. */
  get outputKeys() {
    return [this.outputKey];
  }

  /**
   * @param {object} fields - Chain options.
   * @param {object} fields.llmChain - Chain producing the initial answer from the first document.
   * @param {object} fields.refineLLMChain - Chain refining the answer with each later document.
   * @param {string} [fields.inputKey="input_documents"] - Key holding the input documents.
   * @param {string} [fields.outputKey="output_text"] - Key of the returned answer.
   * @param {string} [fields.documentVariableName="context"] - Prompt variable that receives
   *   the formatted document.
   * @param {string} [fields.initialResponseName="existing_answer"] - Refine-prompt variable
   *   that receives the previous answer.
   * @param {object} [fields.documentPrompt] - Template used to render each document;
   *   defaults to `defaultDocumentPrompt`.
   */
  constructor(fields) {
    super(fields);
    // Define own data properties with their defaults first, then overlay
    // whatever the caller supplied.
    const fieldDefaults = {
      llmChain: undefined,
      inputKey: "input_documents",
      outputKey: "output_text",
      documentVariableName: "context",
      initialResponseName: "existing_answer",
      refineLLMChain: undefined,
      documentPrompt: this.defaultDocumentPrompt,
    };
    for (const [name, value] of Object.entries(fieldDefaults)) {
      Object.defineProperty(this, name, {
        enumerable: true,
        configurable: true,
        writable: true,
        value,
      });
    }
    this.llmChain = fields.llmChain;
    this.refineLLMChain = fields.refineLLMChain;
    this.documentVariableName =
      fields.documentVariableName ?? this.documentVariableName;
    this.inputKey = fields.inputKey ?? this.inputKey;
    this.outputKey = fields.outputKey ?? this.outputKey;
    this.documentPrompt = fields.documentPrompt ?? this.documentPrompt;
    this.initialResponseName =
      fields.initialResponseName ?? this.initialResponseName;
  }

  /**
   * Renders one document with `documentPrompt`. Each template variable is
   * looked up in `{ page_content, ...doc.metadata }`, so a metadata entry
   * named `page_content` overrides the document's own content.
   * @ignore
   */
  async _formatDocument(doc) {
    const available = { page_content: doc.pageContent, ...doc.metadata };
    const templateValues = {};
    for (const variable of this.documentPrompt.inputVariables) {
      templateValues[variable] = available[variable];
    }
    return this.documentPrompt.format(templateValues);
  }

  /**
   * Inputs for the initial `llmChain` call: the formatted document under
   * `documentVariableName`, with `rest` merged on top (entries in `rest` win).
   * @ignore
   */
  async _constructInitialInputs(doc, rest) {
    const formatted = await this._formatDocument(doc);
    return { [this.documentVariableName]: formatted, ...rest };
  }

  /**
   * Inputs for a `refineLLMChain` call: the previous answer under
   * `initialResponseName` and the formatted document under `documentVariableName`.
   * @ignore
   */
  async _constructRefineInputs(doc, res) {
    const formatted = await this._formatDocument(doc);
    return { [this.initialResponseName]: res, [this.documentVariableName]: formatted };
  }

  /**
   * Produces an answer from the first document and refines it with each
   * remaining document, in order.
   * @param {object} values - Must contain a non-empty document list under
   *   `inputKey`; all other entries are forwarded to both chains.
   * @param {object} [runManager] - Optional; its `answer` child is used for the
   *   initial call and its `refine` child for each refine call.
   * @returns {Promise<object>} `{ [outputKey]: finalAnswer }`.
   * @throws {Error} If `values` has no `inputKey` entry or the document list is empty.
   * @ignore
   */
  async _call(values, runManager) {
    if (!(this.inputKey in values)) {
      throw new Error(`Document key ${this.inputKey} not found.`);
    }
    const { [this.inputKey]: docs, ...rest } = values;
    // Without this guard an empty list fails with an opaque TypeError on docs[0].
    if (!docs || docs.length === 0) {
      throw new Error("RefineDocumentsChain requires at least one input document.");
    }
    const initialInputs = await this._constructInitialInputs(docs[0], rest);
    let answer = await this.llmChain.predict(initialInputs, runManager?.getChild("answer"));
    for (let i = 1; i < docs.length; i += 1) {
      const refineInputs = await this._constructRefineInputs(docs[i], answer);
      answer = await this.refineLLMChain.predict(
        { ...refineInputs, ...rest },
        runManager?.getChild("refine"),
      );
    }
    return { [this.outputKey]: answer };
  }

  _chainType() {
    return "refine_documents_chain";
  }

  /**
   * Rebuilds a chain from `serialize()` output. Only the two sub-chains are
   * restored; every other option comes back as its default.
   * @throws {Error} If `data.llm_chain` or `data.refine_llm_chain` is missing.
   */
  static async deserialize(data) {
    const SerializedLLMChain = data.llm_chain;
    if (!SerializedLLMChain) {
      throw new Error("Missing llm_chain");
    }
    const SerializedRefineDocumentChain = data.refine_llm_chain;
    if (!SerializedRefineDocumentChain) {
      throw new Error("Missing refine_llm_chain");
    }
    return new RefineDocumentsChain({
      llmChain: await LLMChain.deserialize(SerializedLLMChain),
      refineLLMChain: await LLMChain.deserialize(SerializedRefineDocumentChain),
    });
  }

  /** Serializes the chain type and both sub-chains. */
  serialize() {
    return {
      _type: this._chainType(),
      llm_chain: this.llmChain.serialize(),
      refine_llm_chain: this.refineLLMChain.serialize(),
    };
  }
}
|