import { getPromptInputKey } from "@langchain/core/memory";
import { getBufferString } from "@langchain/core/messages";
import { InMemoryEntityStore } from "./stores/entity/in_memory.js";
import { LLMChain } from "../chains/llm_chain.js";
import { ENTITY_EXTRACTION_PROMPT, ENTITY_SUMMARIZATION_PROMPT } from "./prompt.js";
import { BaseChatMemory } from "./chat_memory.js";
// Entity extractor & summarizer to memory.
/**
 * Class for managing entity extraction and summarization to memory in
 * chatbot applications. Extends the BaseChatMemory class and implements
 * the EntityMemoryInput interface.
 * @example
 * ```typescript
 * const memory = new EntityMemory({
 *   llm: new ChatOpenAI({ temperature: 0 }),
 *   chatHistoryKey: "history",
 *   entitiesKey: "entities",
 * });
 * const model = new ChatOpenAI({ temperature: 0.9 });
 * const chain = new LLMChain({
 *   llm: model,
 *   prompt: ENTITY_MEMORY_CONVERSATION_TEMPLATE,
 *   memory,
 * });
 *
 * const res1 = await chain.call({ input: "Hi! I'm Jim." });
 * console.log({
 *   res1,
 *   memory: await memory.loadMemoryVariables({ input: "Who is Jim?" }),
 * });
 *
 * const res2 = await chain.call({
 *   input: "I work in construction. What about you?",
 * });
 * console.log({
 *   res2,
 *   memory: await memory.loadMemoryVariables({ input: "Who is Jim?" }),
 * });
 * ```
 */
export class EntityMemory extends BaseChatMemory {
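    /**
     * Creates an EntityMemory instance. Besides the usual BaseChatMemory
     * options, `fields` supplies the LLM used for the extraction and
     * summarization chains, optional custom prompts, an entity store, and
     * the number of recent message pairs (`k`) to include as context.
     * @param fields Configuration options for the memory instance.
     */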
    constructor(fields) {
        super({
            chatHistory: fields.chatHistory,
            returnMessages: fields.returnMessages ?? false,
            inputKey: fields.inputKey,
            outputKey: fields.outputKey,
        });
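        // Declare the instance properties up front with their defaults;
        // the real values are assigned from `fields` further below.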
        Object.defineProperty(this, "entityExtractionChain", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "entitySummarizationChain", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "entityStore", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "entityCache", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: []
        });
        Object.defineProperty(this, "k", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 3
        });
        Object.defineProperty(this, "chatHistoryKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "history"
        });
        Object.defineProperty(this, "llm", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "entitiesKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "entities"
        });
        Object.defineProperty(this, "humanPrefix", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "aiPrefix", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
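        // Populate the properties from the constructor options, keeping the
        // defaults declared above when an option is not provided.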
        this.llm = fields.llm;
        this.humanPrefix = fields.humanPrefix;
        this.aiPrefix = fields.aiPrefix;
        this.chatHistoryKey = fields.chatHistoryKey ?? this.chatHistoryKey;
        this.entitiesKey = fields.entitiesKey ?? this.entitiesKey;
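        // Build the two internal LLM chains: one that extracts entity names
        // from the recent conversation, and one that updates each entity's
        // stored summary.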
        this.entityExtractionChain = new LLMChain({
            llm: this.llm,
            prompt: fields.entityExtractionPrompt ?? ENTITY_EXTRACTION_PROMPT,
        });
        this.entitySummarizationChain = new LLMChain({
            llm: this.llm,
            prompt: fields.entitySummarizationPrompt ?? ENTITY_SUMMARIZATION_PROMPT,
        });
        this.entityStore = fields.entityStore ?? new InMemoryEntityStore();
        this.entityCache = fields.entityCache ?? this.entityCache;
        this.k = fields.k ?? this.k;
    }
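    // Memory keys reported by this class (the chat history key only).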
    get memoryKeys() {
        return [this.chatHistoryKey];
    }
    // Will always return list of memory variables.
    get memoryVariables() {
        return [this.entitiesKey, this.chatHistoryKey];
    }
    // Return history buffer.
    /**
     * Method to load memory variables and perform entity extraction.
     * @param inputs Input values for the method.
     * @returns Promise resolving to an object containing memory variables.
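     * @example
     * ```typescript
     * // Illustrative only; the actual summaries depend on the LLM's output.
     * const vars = await memory.loadMemoryVariables({ input: "Who is Jim?" });
     * // With the default keys, vars has the shape:
     * // { history: "<recent messages>", entities: { Jim: "<summary>" } }
     * ```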
     */
    async loadMemoryVariables(inputs) {
        const promptInputKey = this.inputKey ?? getPromptInputKey(inputs, this.memoryVariables);
        const messages = await this.chatHistory.getMessages();
        const serializedMessages = getBufferString(messages.slice(-this.k * 2), this.humanPrefix, this.aiPrefix);
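        // Ask the extraction chain which entities appear in the recent
        // history plus the incoming input; it answers "NONE" or a
        // comma-separated list of names.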
        const output = await this.entityExtractionChain.predict({
            history: serializedMessages,
            input: inputs[promptInputKey],
        });
        const entities = output.trim() === "NONE" ? [] : output.split(",").map((w) => w.trim());
        const entitySummaries = {};
        for (const entity of entities) {
            entitySummaries[entity] = await this.entityStore.get(entity, "No current information known.");
        }
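        // Cache the extracted entities so saveContext can update their
        // summaries after the model responds.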
        this.entityCache = [...entities];
        const buffer = this.returnMessages
            ? messages.slice(-this.k * 2)
            : serializedMessages;
        return {
            [this.chatHistoryKey]: buffer,
            [this.entitiesKey]: entitySummaries,
        };
    }
    // Save context from this conversation to buffer.
    /**
     * Method to save the context from a conversation to a buffer and perform
     * entity summarization.
     * @param inputs Input values for the method.
     * @param outputs Output values from the method.
     * @returns Promise resolving to void.
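     * @example
     * ```typescript
     * // Illustrative only; the input/output key names depend on the chain
     * // this memory is attached to.
     * await memory.saveContext(
     *   { input: "Hi! I'm Jim." },
     *   { response: "Nice to meet you, Jim." }
     * );
     * ```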
     */
    async saveContext(inputs, outputs) {
        await super.saveContext(inputs, outputs);
        const promptInputKey = this.inputKey ?? getPromptInputKey(inputs, this.memoryVariables);
        const messages = await this.chatHistory.getMessages();
        const serializedMessages = getBufferString(messages.slice(-this.k * 2), this.humanPrefix, this.aiPrefix);
        const inputData = inputs[promptInputKey];
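        // For every entity extracted by the last loadMemoryVariables call,
        // ask the summarization chain to revise its stored summary in light
        // of the new exchange.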
        for (const entity of this.entityCache) {
            const existingSummary = await this.entityStore.get(entity, "No current information known.");
            const output = await this.entitySummarizationChain.predict({
                summary: existingSummary,
                entity,
                history: serializedMessages,
                input: inputData,
            });
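            // Only write back when the chain reports new information;
            // "UNCHANGED" means the existing summary still holds.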
            if (output.trim() !== "UNCHANGED") {
                await this.entityStore.set(entity, output.trim());
            }
        }
    }
    // Clear memory contents.
    /**
     * Method to clear the memory contents.
     * @returns Promise resolving to void.
     */
    async clear() {
        await super.clear();
        await this.entityStore.clear();
    }
}