agsamantha/node_modules/langchain/dist/chains/analyze_documents_chain.js
2024-10-02 15:15:21 -05:00

95 lines
3.2 KiB
JavaScript

import { BaseChain } from "./base.js";
import { RecursiveCharacterTextSplitter, } from "../text_splitter.js";
/**
 * Chain that analyzes a single (potentially long) text by splitting it into
 * documents and handing those documents to another chain that combines them
 * (for example a summarization chain).
 *
 * The raw text is read from `values[inputKey]` (default `"input_document"`),
 * split with `textSplitter` (default: a `RecursiveCharacterTextSplitter`
 * constructed with no arguments), and the resulting documents are passed to
 * `combineDocumentsChain` under the `input_documents` key together with every
 * other input value. The output keys are those of `combineDocumentsChain`.
 *
 * @augments BaseChain
 * @example
 * ```typescript
 * const model = new ChatOpenAI({ temperature: 0 });
 * const combineDocsChain = loadSummarizationChain(model);
 * const chain = new AnalyzeDocumentChain({
 *   combineDocumentsChain: combineDocsChain,
 * });
 *
 * // Read the text from a file (this is a placeholder for actual file reading)
 * const text = readTextFromFile("state_of_the_union.txt");
 *
 * // Invoke the chain to analyze the document
 * const res = await chain.invoke({
 *   input_document: text,
 * });
 *
 * console.log({ res });
 * ```
 */
export class AnalyzeDocumentChain extends BaseChain {
  static lc_name() {
    return "AnalyzeDocumentChain";
  }

  /**
   * @param fields - Constructor options (also forwarded to `BaseChain`).
   * @param fields.combineDocumentsChain - Chain that receives the split
   *   documents under the `input_documents` key.
   * @param [fields.inputKey="input_document"] - Input key holding the text to split.
   * @param [fields.textSplitter] - Splitter used to turn the text into
   *   documents; defaults to `new RecursiveCharacterTextSplitter()`.
   */
  constructor(fields) {
    super(fields);
    Object.defineProperty(this, "inputKey", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: "input_document"
    });
    Object.defineProperty(this, "combineDocumentsChain", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0
    });
    Object.defineProperty(this, "textSplitter", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0
    });
    this.combineDocumentsChain = fields.combineDocumentsChain;
    this.inputKey = fields.inputKey ?? this.inputKey;
    this.textSplitter =
      fields.textSplitter ?? new RecursiveCharacterTextSplitter();
  }

  /** The single input key: the key holding the text to analyze. */
  get inputKeys() {
    return [this.inputKey];
  }

  /** Output keys are delegated to the wrapped combine-documents chain. */
  get outputKeys() {
    return this.combineDocumentsChain.outputKeys;
  }

  /**
   * Split the input text and run the combine-documents chain on the chunks.
   *
   * @param values - Chain inputs; must contain `this.inputKey`.
   * @param runManager - Optional run manager; a child tagged
   *   `"combine_documents"` is passed to the inner chain.
   * @returns The inner chain's result, unchanged.
   * @throws {Error} If `this.inputKey` is not present in `values`.
   * @ignore
   */
  async _call(values, runManager) {
    if (!(this.inputKey in values)) {
      throw new Error(`Document key ${this.inputKey} not found.`);
    }
    // Pull the raw text out and forward every other input unchanged, so the
    // combine chain can still receive any extra variables it expects.
    const { [this.inputKey]: text, ...rest } = values;
    const docs = await this.textSplitter.createDocuments([text]);
    const newInputs = { input_documents: docs, ...rest };
    return this.combineDocumentsChain.call(newInputs, runManager?.getChild("combine_documents"));
  }

  _chainType() {
    return "analyze_document_chain";
  }

  /**
   * Rebuild a chain from the output of `serialize()`.
   *
   * The serialized form contains neither the text splitter nor a custom
   * `inputKey`, so the splitter must be supplied again through `values` and
   * the restored chain uses the default input key.
   *
   * @param data - Serialized chain; must contain `combine_document_chain`.
   * @param values - Extra construction values; must contain `text_splitter`.
   * @returns A new `AnalyzeDocumentChain`.
   * @throws {Error} If `text_splitter` or `combine_document_chain` is missing.
   */
  static async deserialize(data, values) {
    // `?? {}` / `?.` make a missing argument report the descriptive error
    // below instead of a TypeError from `in` / property access on undefined.
    if (!("text_splitter" in (values ?? {}))) {
      throw new Error(`Need to pass in a text_splitter to deserialize AnalyzeDocumentChain.`);
    }
    const { text_splitter } = values;
    if (!data?.combine_document_chain) {
      throw new Error(`Need to pass in a combine_document_chain to deserialize AnalyzeDocumentChain.`);
    }
    return new AnalyzeDocumentChain({
      combineDocumentsChain: await BaseChain.deserialize(data.combine_document_chain),
      textSplitter: text_splitter,
    });
  }

  /**
   * Serialize the chain type and the wrapped combine-documents chain.
   * The text splitter and `inputKey` are not included.
   */
  serialize() {
    return {
      _type: this._chainType(),
      combine_document_chain: this.combineDocumentsChain.serialize(),
    };
  }
}