// agsamantha/node_modules/@langchain/community/dist/chat_models/premai.js

import { AIMessage, AIMessageChunk, ChatMessage, ChatMessageChunk, HumanMessageChunk, } from "@langchain/core/messages";
import { BaseChatModel, } from "@langchain/core/language_models/chat_models";
import Prem from "@premai/prem-sdk";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { ChatGenerationChunk, } from "@langchain/core/outputs";
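// Extracts the custom role from a generic ChatMessage, warning when the role
// is not one the Prem API recognizes ("assistant" or "user").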
function extractGenericMessageCustomRole(message) {
if (message.role !== "assistant" && message.role !== "user") {
console.warn(`Unknown message role: ${message.role}`);
}
return message.role;
}
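// Maps a LangChain message type ("ai", "human", "generic") to the
// corresponding Prem chat role.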
export function messageToPremRole(message) {
const type = message._getType();
switch (type) {
case "ai":
return "assistant";
case "human":
return "user";
case "generic": {
if (!ChatMessage.isInstance(message))
throw new Error("Invalid generic chat message");
return extractGenericMessageCustomRole(message);
}
default:
throw new Error(`Unknown message type: ${type}`);
}
}
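// Converts LangChain messages into the plain { role, content, ... } objects
// expected by the Prem chat completions API; only string content is supported.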
function convertMessagesToPremParams(messages) {
return messages.map((message) => {
if (typeof message.content !== "string") {
throw new Error("Non string message content not supported");
}
return {
role: messageToPremRole(message),
content: message.content,
name: message.name,
function_call: message.additional_kwargs.function_call,
};
});
}
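// Converts a Prem response message back into a LangChain message,
// defaulting to a generic ChatMessage for non-assistant roles.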
function premResponseToChatMessage(message) {
switch (message.role) {
case "assistant":
return new AIMessage(message.content || "");
default:
return new ChatMessage(message.content || "", message.role ?? "unknown");
}
}
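// Converts a streaming delta into the appropriate message chunk class,
// preserving any function_call payload in additional_kwargs.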
function _convertDeltaToMessageChunk(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
delta) {
const { role } = delta;
const content = delta.content ?? "";
let additional_kwargs;
if (delta.function_call) {
additional_kwargs = {
function_call: delta.function_call,
};
}
else {
additional_kwargs = {};
}
if (role === "user") {
return new HumanMessageChunk({ content });
}
else if (role === "assistant") {
return new AIMessageChunk({ content, additional_kwargs });
}
else {
return new ChatMessageChunk({ content, role });
}
}
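// Example usage (a minimal sketch, not part of this file; the import path is
// the one commonly exposed by @langchain/community, and it assumes
// PREM_API_KEY and PREM_PROJECT_ID are set in the environment):
//
//   import { ChatPrem } from "@langchain/community/chat_models/premai";
//   import { HumanMessage } from "@langchain/core/messages";
//
//   const chat = new ChatPrem({ temperature: 0.7 });
//   const response = await chat.invoke([new HumanMessage("Hello!")]);
//   console.log(response.content);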
/**
 * LangChain chat model integration for the Prem AI platform, backed by the
 * `@premai/prem-sdk` client.
 */
export class ChatPrem extends BaseChatModel {
// Serialization and tracing name for this class.
static lc_name() {
return "ChatPrem";
}
/**
 * Maps constructor arguments to the environment variables that back them,
 * so the API key is treated as a secret during serialization.
 */
get lc_secrets() {
return {
apiKey: "PREM_API_KEY",
};
}
get lc_aliases() {
return {
apiKey: "PREM_API_KEY",
};
}
constructor(fields) {
super(fields ?? {});
Object.defineProperty(this, "client", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "apiKey", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "project_id", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "session_id", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "messages", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "model", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "system_prompt", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "frequency_penalty", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "logit_bias", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "max_tokens", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "n", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "presence_penalty", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "response_format", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "seed", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "stop", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "temperature", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "top_p", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "tools", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "user", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "streaming", {
enumerable: true,
configurable: true,
writable: true,
value: false
});
Object.defineProperty(this, "lc_serializable", {
enumerable: true,
configurable: true,
writable: true,
value: true
});
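// Resolve credentials: explicit constructor fields take precedence over
// the PREM_API_KEY / PREM_PROJECT_ID environment variables.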
const apiKey = fields?.apiKey ?? getEnvironmentVariable("PREM_API_KEY");
if (!apiKey) {
throw new Error(`Prem API key not found. Please set the PREM_API_KEY environment variable or pass it in the "apiKey" field.`);
}
const projectId = fields?.project_id ??
parseInt(getEnvironmentVariable("PREM_PROJECT_ID") ?? "-1", 10);
if (!projectId || projectId === -1 || typeof projectId !== "number") {
throw new Error(`Prem project ID not found. Please set the PREM_PROJECT_ID environment variable or pass it in the "project_id" field.`);
}
this.client = new Prem({
apiKey,
});
this.project_id = projectId;
this.session_id = fields?.session_id ?? this.session_id;
this.messages = fields?.messages ?? this.messages;
this.model = fields?.model ?? this.model;
this.system_prompt = fields?.system_prompt ?? this.system_prompt;
this.frequency_penalty =
fields?.frequency_penalty ?? this.frequency_penalty;
this.logit_bias = fields?.logit_bias ?? this.logit_bias;
this.max_tokens = fields?.max_tokens ?? this.max_tokens;
this.n = fields?.n ?? this.n;
this.presence_penalty = fields?.presence_penalty ?? this.presence_penalty;
this.response_format = fields?.response_format ?? this.response_format;
this.seed = fields?.seed ?? this.seed;
this.stop = fields?.stop ?? this.stop;
this.temperature = fields?.temperature ?? this.temperature;
this.top_p = fields?.top_p ?? this.top_p;
this.tools = fields?.tools ?? this.tools;
this.user = fields?.user ?? this.user;
this.streaming = fields?.streaming ?? this.streaming;
}
// Type tag for this LLM, used in tracing and callbacks.
_llmType() {
return "prem";
}
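// Calls the Prem SDK's chat completions endpoint through `this.caller`,
// which applies the model's retry and concurrency settings.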
async completionWithRetry(request,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
options) {
return this.caller.call(async () => this.client.chat.completions.create(request, options));
}
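// Assembles the full parameter object sent with each chat completions
// request, merging base params with this instance's Prem-specific fields.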
invocationParams(options) {
const params = super.invocationParams(options);
return {
...params,
project_id: this.project_id,
session_id: this.session_id,
messages: this.messages,
model: this.model,
system_prompt: this.system_prompt,
frequency_penalty: this.frequency_penalty,
logit_bias: this.logit_bias,
max_tokens: this.max_tokens,
n: this.n,
presence_penalty: this.presence_penalty,
response_format: this.response_format,
seed: this.seed,
stop: this.stop,
temperature: this.temperature,
top_p: this.top_p,
tools: this.tools,
user: this.user,
streaming: this.streaming,
stream: this.streaming,
};
}
/**
 * Streams response chunks from the Prem chat completions API, yielding one
 * ChatGenerationChunk per streamed delta.
 */
async *_streamResponseChunks(messages, options, runManager) {
const params = this.invocationParams(options);
const messagesMapped = convertMessagesToPremParams(messages);
// All models have a built-in `this.caller` property for retries
const stream = await this.caller.call(async () => this.completionWithRetry({
...params,
messages: messagesMapped,
stream: true,
}, params));
for await (const data of stream) {
const choice = data?.choices[0];
if (!choice) {
continue;
}
const chunk = new ChatGenerationChunk({
message: _convertDeltaToMessageChunk(choice.delta ?? {}),
text: choice.delta?.content ?? "",
generationInfo: {
finishReason: choice.finish_reason,
},
});
yield chunk;
void runManager?.handleLLMNewToken(chunk.text ?? "");
}
if (options.signal?.aborted) {
throw new Error("AbortError");
}
}
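// Streaming usage (a minimal sketch, not part of this file; `chat` is the
// instance from the example above, and the environment assumptions are the same):
//
//   const stream = await chat.stream([new HumanMessage("Tell me a story.")]);
//   for await (const chunk of stream) {
//     process.stdout.write(typeof chunk.content === "string" ? chunk.content : "");
//   }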
/** @ignore */
_combineLLMOutput() {
return [];
}
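// Generates a chat completion. In streaming mode the response is assembled
// from _streamResponseChunks; otherwise a single request is made and token
// usage is accumulated from the response.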
async _generate(messages, options, runManager) {
const tokenUsage = {};
const params = this.invocationParams(options);
const messagesMapped = convertMessagesToPremParams(messages);
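// When streaming is enabled, consume the chunk stream and merge chunks
// that share the same completion index.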
if (params.streaming) {
const stream = this._streamResponseChunks(messages, options, runManager);
const finalChunks = {};
for await (const chunk of stream) {
const index = chunk.generationInfo?.completion ?? 0;
if (finalChunks[index] === undefined) {
finalChunks[index] = chunk;
}
else {
finalChunks[index] = finalChunks[index].concat(chunk);
}
}
const generations = Object.entries(finalChunks)
.sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10))
.map(([_, value]) => value);
return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } };
}
else {
const data = await this.completionWithRetry({
...params,
stream: false,
messages: messagesMapped,
}, {
signal: options?.signal,
});
if ("usage" in data && data.usage) {
const { completion_tokens: completionTokens, prompt_tokens: promptTokens, total_tokens: totalTokens, } = data.usage;
if (completionTokens) {
tokenUsage.completionTokens =
(tokenUsage.completionTokens ?? 0) + completionTokens;
}
if (promptTokens) {
tokenUsage.promptTokens =
(tokenUsage.promptTokens ?? 0) + promptTokens;
}
if (totalTokens) {
tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens;
}
}
const generations = [];
if ("choices" in data && data.choices) {
for (const part of data.choices) {
const text = part.message?.content ?? "";
const generation = {
text,
message: premResponseToChatMessage(part.message ?? { role: "assistant" }),
};
generation.generationInfo = {
...(part.finish_reason
? { finish_reason: part.finish_reason }
: {}),
...(part.logprobs ? { logprobs: part.logprobs } : {}),
};
generations.push(generation);
}
}
return {
generations,
llmOutput: { tokenUsage },
};
}
}
}