agsamantha/node_modules/langchain/dist/memory/entity_memory.js

199 lines
7.1 KiB
JavaScript
Raw Normal View History

2024-10-02 15:15:21 -05:00
import { getPromptInputKey, } from "@langchain/core/memory";
import { getBufferString } from "@langchain/core/messages";
import { InMemoryEntityStore } from "./stores/entity/in_memory.js";
import { LLMChain } from "../chains/llm_chain.js";
import { ENTITY_EXTRACTION_PROMPT, ENTITY_SUMMARIZATION_PROMPT, } from "./prompt.js";
import { BaseChatMemory } from "./chat_memory.js";
// Entity extractor & summarizer to memory.
/**
 * Class for managing entity extraction and summarization to memory in
 * chatbot applications. Extends the BaseChatMemory class and implements
 * the EntityMemoryInput interface.
 *
 * Flow:
 * - `loadMemoryVariables` asks the LLM (via `entityExtractionChain`) which
 *   entities appear in the current input, given the last `k` exchanges. It
 *   returns each entity's stored summary together with the chat history, and
 *   remembers the extracted names in `entityCache`.
 * - `saveContext` appends the turn to the chat history, then re-summarizes
 *   every entity in `entityCache` (via `entitySummarizationChain`). It writes
 *   the new summary to `entityStore` unless the LLM answers "UNCHANGED".
 *   Entity updates therefore depend on a preceding `loadMemoryVariables` call.
 * @example
 * ```typescript
 * const memory = new EntityMemory({
 *   llm: new ChatOpenAI({ temperature: 0 }),
 *   chatHistoryKey: "history",
 *   entitiesKey: "entities",
 * });
 * const model = new ChatOpenAI({ temperature: 0.9 });
 * const chain = new LLMChain({
 *   llm: model,
 *   prompt: ENTITY_MEMORY_CONVERSATION_TEMPLATE,
 *   memory,
 * });
 *
 * const res1 = await chain.call({ input: "Hi! I'm Jim." });
 * console.log({
 *   res1,
 *   memory: await memory.loadMemoryVariables({ input: "Who is Jim?" }),
 * });
 *
 * const res2 = await chain.call({
 *   input: "I work in construction. What about you?",
 * });
 * console.log({
 *   res2,
 *   memory: await memory.loadMemoryVariables({ input: "Who is Jim?" }),
 * });
 *
 * ```
 */
export class EntityMemory extends BaseChatMemory {
    /**
     * @param fields Configuration. `llm` is required; every other field falls
     *   back to the defaults declared below or to the built-in prompts/store.
     */
    constructor(fields) {
        super({
            chatHistory: fields.chatHistory,
            returnMessages: fields.returnMessages ?? false,
            inputKey: fields.inputKey,
            outputKey: fields.outputKey,
        });
        // Class-field declarations (as emitted by the TypeScript compiler).
        Object.defineProperty(this, "entityExtractionChain", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "entitySummarizationChain", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "entityStore", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        // Entity names extracted by the most recent loadMemoryVariables call.
        Object.defineProperty(this, "entityCache", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: []
        });
        // Number of recent human/AI exchanges (2 * k messages) fed to the LLM.
        Object.defineProperty(this, "k", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 3
        });
        Object.defineProperty(this, "chatHistoryKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "history"
        });
        Object.defineProperty(this, "llm", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "entitiesKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "entities"
        });
        Object.defineProperty(this, "humanPrefix", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "aiPrefix", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        this.llm = fields.llm;
        this.humanPrefix = fields.humanPrefix;
        this.aiPrefix = fields.aiPrefix;
        this.chatHistoryKey = fields.chatHistoryKey ?? this.chatHistoryKey;
        this.entitiesKey = fields.entitiesKey ?? this.entitiesKey;
        this.entityExtractionChain = new LLMChain({
            llm: this.llm,
            prompt: fields.entityExtractionPrompt ?? ENTITY_EXTRACTION_PROMPT,
        });
        this.entitySummarizationChain = new LLMChain({
            llm: this.llm,
            prompt: fields.entitySummarizationPrompt ?? ENTITY_SUMMARIZATION_PROMPT,
        });
        this.entityStore = fields.entityStore ?? new InMemoryEntityStore();
        this.entityCache = fields.entityCache ?? this.entityCache;
        this.k = fields.k ?? this.k;
    }
    get memoryKeys() {
        return [this.chatHistoryKey];
    }
    /**
     * Names of the variables produced by `loadMemoryVariables`: the entity
     * summaries key and the chat history key. Always both.
     */
    get memoryVariables() {
        return [this.entitiesKey, this.chatHistoryKey];
    }
    /**
     * Extract the entities mentioned in the current input and load their
     * stored summaries together with the recent chat history.
     * Side effect: replaces `entityCache` with the extracted entity names,
     * which the next `saveContext` call will re-summarize.
     * @param inputs Input values for the method; the prompt input is read
     *   from `this.inputKey` or, if unset, inferred via `getPromptInputKey`.
     * @returns Promise resolving to an object with the history under
     *   `chatHistoryKey` (messages if `returnMessages`, else a string) and a
     *   map from entity name to summary under `entitiesKey`.
     */
    async loadMemoryVariables(inputs) {
        const promptInputKey = this.inputKey ?? getPromptInputKey(inputs, this.memoryVariables);
        const messages = await this.chatHistory.getMessages();
        // Keep the last k exchanges. k <= 0 is guarded explicitly because
        // `slice(-0)` is `slice(0)` and would return the whole history.
        const recentMessages = this.k > 0 ? messages.slice(-this.k * 2) : [];
        const serializedMessages = getBufferString(recentMessages, this.humanPrefix, this.aiPrefix);
        const output = await this.entityExtractionChain.predict({
            history: serializedMessages,
            input: inputs[promptInputKey],
        });
        const trimmedOutput = output.trim();
        // A literal "NONE" means no entities; otherwise the output is parsed
        // as a comma-separated list. Blank names (e.g. from a trailing comma or
        // empty output) and duplicates are dropped, keeping first-seen order.
        const entities = trimmedOutput === "NONE"
            ? []
            : [
                ...new Set(trimmedOutput
                    .split(",")
                    .map((name) => name.trim())
                    .filter((name) => name.length > 0)),
            ];
        // Store reads are independent of each other, so fetch them concurrently.
        const summaries = await Promise.all(entities.map((entity) => this.entityStore.get(entity, "No current information known.")));
        // Object.fromEntries defines own properties, so an entity literally
        // named "__proto__" becomes a key instead of hitting the prototype setter.
        const entitySummaries = Object.fromEntries(entities.map((entity, i) => [entity, summaries[i]]));
        this.entityCache = entities;
        const buffer = this.returnMessages ? recentMessages : serializedMessages;
        return {
            [this.chatHistoryKey]: buffer,
            [this.entitiesKey]: entitySummaries,
        };
    }
    /**
     * Save the conversation turn to the chat history, then update the stored
     * summary of every entity found by the last `loadMemoryVariables` call.
     * A summary is left untouched when the summarization LLM replies
     * "UNCHANGED".
     * @param inputs Input values for the method.
     * @param outputs Output values from the method.
     * @returns Promise resolving to void.
     */
    async saveContext(inputs, outputs) {
        await super.saveContext(inputs, outputs);
        if (this.entityCache.length === 0) {
            // Nothing to re-summarize; skip fetching and serializing history.
            return;
        }
        const promptInputKey = this.inputKey ?? getPromptInputKey(inputs, this.memoryVariables);
        const messages = await this.chatHistory.getMessages();
        // Same k <= 0 guard as in loadMemoryVariables (`slice(-0)` returns all).
        const recentMessages = this.k > 0 ? messages.slice(-this.k * 2) : [];
        const serializedMessages = getBufferString(recentMessages, this.humanPrefix, this.aiPrefix);
        const inputData = inputs[promptInputKey];
        // Kept sequential: each iteration issues an LLM call and a store write,
        // and running them all at once could flood the model provider.
        for (const entity of this.entityCache) {
            const existingSummary = await this.entityStore.get(entity, "No current information known.");
            const output = await this.entitySummarizationChain.predict({
                summary: existingSummary,
                entity,
                history: serializedMessages,
                input: inputData,
            });
            const newSummary = output.trim();
            if (newSummary !== "UNCHANGED") {
                await this.entityStore.set(entity, newSummary);
            }
        }
    }
    /**
     * Clear the chat history, the pending entity cache, and the entity store.
     * @returns Promise resolving to void.
     */
    async clear() {
        await super.clear();
        // Without this, the next saveContext would re-summarize the entities
        // from before the clear and write them back into the emptied store.
        // Assign a fresh array rather than mutating one the caller may own.
        this.entityCache = [];
        await this.entityStore.clear();
    }
}