agsamantha/node_modules/@langchain/community/dist/chat_models/llama_cpp.js

/* eslint-disable import/no-extraneous-dependencies */
import { LlamaChatSession, } from "node-llama-cpp";
import { SimpleChatModel, } from "@langchain/core/language_models/chat_models";
import { AIMessageChunk, ChatMessage, } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";
import { createLlamaModel, createLlamaContext, } from "../utils/llama_cpp.js";
/**
* To use this model you need to have the `node-llama-cpp` module installed.
* This can be installed using `npm install -S node-llama-cpp`; the minimum
* supported version is 2.0.0.
* This also requires that you have a locally available Llama2 model to load.
* @example
* ```typescript
* // Initialize the ChatLlamaCpp model with the path to the model binary file.
* const model = new ChatLlamaCpp({
* modelPath: "/Replace/with/path/to/your/model/gguf-llama2-q4_0.bin",
* temperature: 0.5,
* });
*
* // Call the model with a message and await the response.
* const response = await model.invoke([
* new HumanMessage({ content: "My name is John." }),
* ]);
*
* // Log the response to the console.
* console.log({ response });
*
* ```
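*
* // Streaming usage: a minimal sketch that reuses the `model` instance created
* // above and the standard `.stream()` runnable method.
* ```typescript
* const stream = await model.stream([
*   new HumanMessage({ content: "Tell me a short joke." }),
* ]);
* for await (const chunk of stream) {
*   console.log(chunk.content);
* }
* ```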
*/
export class ChatLlamaCpp extends SimpleChatModel {
static lc_name() {
return "ChatLlamaCpp";
}
constructor(inputs) {
super(inputs);
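// Instance fields are declared via Object.defineProperty because this file is
// compiled TypeScript output; each field is then assigned from `inputs` below.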
Object.defineProperty(this, "maxTokens", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "temperature", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "topK", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "topP", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "trimWhitespaceSuffix", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "_model", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "_context", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "_session", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "lc_serializable", {
enumerable: true,
configurable: true,
writable: true,
value: true
});
this.maxTokens = inputs?.maxTokens;
this.temperature = inputs?.temperature;
this.topK = inputs?.topK;
this.topP = inputs?.topP;
this.trimWhitespaceSuffix = inputs?.trimWhitespaceSuffix;
this._model = createLlamaModel(inputs);
this._context = createLlamaContext(this._model, inputs);
this._session = null;
}
_llmType() {
return "llama2_cpp";
}
/** @ignore */
_combineLLMOutput() {
return {};
}
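/** Returns the generation parameters (maxTokens, temperature, topK, topP,
 *  trimWhitespaceSuffix) that are forwarded to node-llama-cpp on each call. */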
invocationParams() {
return {
maxTokens: this.maxTokens,
temperature: this.temperature,
topK: this.topK,
topP: this.topP,
trimWhitespaceSuffix: this.trimWhitespaceSuffix,
};
}
/** @ignore */
async _call(messages, options, runManager) {
let prompt = "";
if (messages.length > 1) {
// We need to build a new _session
prompt = this._buildSession(messages);
}
else if (!this._session) {
prompt = this._buildSession(messages);
}
else {
if (typeof messages[0].content !== "string") {
throw new Error("ChatLlamaCpp does not support non-string message content in sessions.");
}
// If we already have a session then we should just have a single prompt
prompt = messages[0].content;
}
try {
const promptOptions = {
signal: options.signal,
onToken: async (tokens) => {
options.onToken?.(tokens);
await runManager?.handleLLMNewToken(this._context.decode(tokens));
},
maxTokens: this?.maxTokens,
temperature: this?.temperature,
topK: this?.topK,
topP: this?.topP,
trimWhitespaceSuffix: this?.trimWhitespaceSuffix,
};
// @ts-expect-error - TS2531: Object is possibly 'null'.
const completion = await this._session.prompt(prompt, promptOptions);
return completion;
}
catch (e) {
if (typeof e === "object") {
const error = e;
if (error.message === "AbortError") {
throw error;
}
}
throw new Error("Error getting prompt completion.");
}
}
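// Streams generation token by token: the prompt is encoded and evaluated directly
// against the llama context (bypassing the chat session), and each returned token
// is decoded into a ChatGenerationChunk as it arrives.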
async *_streamResponseChunks(input, _options, runManager) {
const promptOptions = {
temperature: this?.temperature,
topK: this?.topK,
topP: this?.topP,
};
const prompt = this._buildPrompt(input);
const stream = await this.caller.call(async () => this._context.evaluate(this._context.encode(prompt), promptOptions));
for await (const chunk of stream) {
yield new ChatGenerationChunk({
text: this._context.decode([chunk]),
message: new AIMessageChunk({
content: this._context.decode([chunk]),
}),
generationInfo: {},
});
await runManager?.handleLLMNewToken(this._context.decode([chunk]) ?? "");
}
}
// This constructs a new session if needed, adding in any system messages or previous chats
_buildSession(messages) {
let prompt = "";
let sysMessage = "";
let noSystemMessages = [];
let interactions = [];
// Let's see if we have a system message
if (messages.findIndex((msg) => msg._getType() === "system") !== -1) {
const sysMessages = messages.filter((message) => message._getType() === "system");
const systemMessageContent = sysMessages[sysMessages.length - 1].content;
if (typeof systemMessageContent !== "string") {
throw new Error("ChatLlamaCpp does not support non-string message content in sessions.");
}
// Only use the last provided system message
sysMessage = systemMessageContent;
// Now filter out the system messages
noSystemMessages = messages.filter((message) => message._getType() !== "system");
}
else {
noSystemMessages = messages;
}
// Let's see if we just have a prompt left or whether there are previous interactions
if (noSystemMessages.length > 1) {
// Is the last message a prompt?
if (noSystemMessages[noSystemMessages.length - 1]._getType() === "human") {
const finalMessageContent = noSystemMessages[noSystemMessages.length - 1].content;
if (typeof finalMessageContent !== "string") {
throw new Error("ChatLlamaCpp does not support non-string message content in sessions.");
}
prompt = finalMessageContent;
interactions = this._convertMessagesToInteractions(noSystemMessages.slice(0, noSystemMessages.length - 1));
}
else {
interactions = this._convertMessagesToInteractions(noSystemMessages);
}
}
else {
if (typeof noSystemMessages[0].content !== "string") {
throw new Error("ChatLlamaCpp does not support non-string message content in sessions.");
}
// If there was only a single message we assume it's a prompt
prompt = noSystemMessages[0].content;
}
// Now let's construct a session according to what we got
if (sysMessage !== "" && interactions.length > 0) {
this._session = new LlamaChatSession({
context: this._context,
conversationHistory: interactions,
systemPrompt: sysMessage,
});
}
else if (sysMessage !== "" && interactions.length === 0) {
this._session = new LlamaChatSession({
context: this._context,
systemPrompt: sysMessage,
});
}
else if (sysMessage === "" && interactions.length > 0) {
this._session = new LlamaChatSession({
context: this._context,
conversationHistory: interactions,
});
}
else {
this._session = new LlamaChatSession({
context: this._context,
});
}
return prompt;
}
// This builds an array of interactions, pairing each prompt message with the response that follows it
_convertMessagesToInteractions(messages) {
const result = [];
for (let i = 0; i < messages.length; i += 2) {
if (i + 1 < messages.length) {
const prompt = messages[i].content;
const response = messages[i + 1].content;
if (typeof prompt !== "string" || typeof response !== "string") {
throw new Error("ChatLlamaCpp does not support non-string message content.");
}
result.push({
prompt,
response,
});
}
}
return result;
}
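// Formats the message history into a single Llama2-style prompt string, wrapping
// human messages in [INST] ... [/INST] and system messages in <<SYS>> ... <</SYS>>.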
_buildPrompt(input) {
const prompt = input
.map((message) => {
let messageText;
if (message._getType() === "human") {
messageText = `[INST] ${message.content} [/INST]`;
}
else if (message._getType() === "ai") {
messageText = message.content;
}
else if (message._getType() === "system") {
messageText = `<<SYS>> ${message.content} <</SYS>>`;
}
else if (ChatMessage.isInstance(message)) {
messageText = `\n\n${message.role[0].toUpperCase()}${message.role.slice(1)}: ${message.content}`;
}
else {
console.warn(`Unsupported message type passed to llama_cpp: "${message._getType()}"`);
messageText = "";
}
return messageText;
})
.join("\n");
return prompt;
}
}