import { AIMessage, AIMessageChunk, ChatMessage, ChatMessageChunk, HumanMessageChunk, } from "@langchain/core/messages";
import { BaseChatModel } from "@langchain/core/language_models/chat_models";
import Prem from "@premai/prem-sdk";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { ChatGenerationChunk } from "@langchain/core/outputs";
function extractGenericMessageCustomRole(message) {
    if (message.role !== "assistant" && message.role !== "user") {
        console.warn(`Unknown message role: ${message.role}`);
    }
    return message.role;
}
export function messageToPremRole(message) {
    const type = message._getType();
    switch (type) {
        case "ai":
            return "assistant";
        case "human":
            return "user";
        case "generic": {
            if (!ChatMessage.isInstance(message))
                throw new Error("Invalid generic chat message");
            return extractGenericMessageCustomRole(message);
        }
        default:
            throw new Error(`Unknown message type: ${type}`);
    }
}
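// Illustrative role mapping (not executed; HumanMessage would need to be
// imported from "@langchain/core/messages"):
//   messageToPremRole(new AIMessage("hi"));                // => "assistant"
//   messageToPremRole(new HumanMessage("hi"));             // => "user"
//   messageToPremRole(new ChatMessage("hi", "moderator")); // warns, returns "moderator"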
function convertMessagesToPremParams(messages) {
    return messages.map((message) => {
        if (typeof message.content !== "string") {
            throw new Error("Non-string message content is not supported");
        }
        return {
            role: messageToPremRole(message),
            content: message.content,
            name: message.name,
            function_call: message.additional_kwargs.function_call,
        };
    });
}
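// Illustrative conversion (not executed): new HumanMessage("What is a GPU?")
// maps to { role: "user", content: "What is a GPU?", name: undefined,
// function_call: undefined }. Multimodal (array) content throws instead.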
function premResponseToChatMessage(message) {
    switch (message.role) {
        case "assistant":
            return new AIMessage(message.content || "");
        default:
            return new ChatMessage(message.content || "", message.role ?? "unknown");
    }
}
function _convertDeltaToMessageChunk(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
delta) {
    const { role } = delta;
    const content = delta.content ?? "";
    let additional_kwargs;
    if (delta.function_call) {
        additional_kwargs = {
            function_call: delta.function_call,
        };
    }
    else {
        additional_kwargs = {};
    }
    if (role === "user") {
        return new HumanMessageChunk({ content });
    }
    else if (role === "assistant") {
        return new AIMessageChunk({ content, additional_kwargs });
    }
    else {
        return new ChatMessageChunk({ content, role });
    }
}
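// Illustrative delta conversion (not executed): a streamed delta of
// { role: "assistant", content: "Hel" } becomes an AIMessageChunk whose
// content is "Hel"; successive chunks of one generation can be merged with
// .concat(), as _generate does below when streaming.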
/**
 * Integration with the Prem AI chat model.
 */
export class ChatPrem extends BaseChatModel {
    // Used for tracing; must match the class name.
    static lc_name() {
        return "ChatPrem";
    }
    /**
     * Secrets passed to `super`, mapped to the environment variables that
     * hold them. See {@link ../../langchain-cohere/src/chat_model.ts} for
     * an example.
     */
    get lc_secrets() {
        return {
            apiKey: "PREM_API_KEY",
        };
    }
    get lc_aliases() {
        return {
            apiKey: "PREM_API_KEY",
        };
    }
    constructor(fields) {
        super(fields ?? {});
        // Field declarations (behavior-equivalent to the original compiled
        // enumerable/configurable/writable Object.defineProperty calls).
        this.client = undefined;
        this.apiKey = undefined;
        this.project_id = undefined;
        this.session_id = undefined;
        this.messages = undefined;
        this.model = undefined;
        this.system_prompt = undefined;
        this.frequency_penalty = undefined;
        this.logit_bias = undefined;
        this.max_tokens = undefined;
        this.n = undefined;
        this.presence_penalty = undefined;
        this.response_format = undefined;
        this.seed = undefined;
        this.stop = undefined;
        this.temperature = undefined;
        this.top_p = undefined;
        this.tools = undefined;
        this.user = undefined;
        this.streaming = false;
        this.lc_serializable = true;
        const apiKey = fields?.apiKey ?? getEnvironmentVariable("PREM_API_KEY");
        if (!apiKey) {
            throw new Error(`Prem API key not found. Please set the PREM_API_KEY environment variable or pass the key in via "apiKey".`);
        }
        const projectId = fields?.project_id ??
            parseInt(getEnvironmentVariable("PREM_PROJECT_ID") ?? "-1", 10);
        if (!projectId || projectId === -1 || typeof projectId !== "number") {
            throw new Error(`Prem project ID not found. Please set the PREM_PROJECT_ID environment variable or pass the ID in via "project_id".`);
        }
        this.client = new Prem({
            apiKey,
        });
        this.project_id = projectId;
        this.session_id = fields?.session_id ?? this.session_id;
        this.messages = fields?.messages ?? this.messages;
        this.model = fields?.model ?? this.model;
        this.system_prompt = fields?.system_prompt ?? this.system_prompt;
        this.frequency_penalty =
            fields?.frequency_penalty ?? this.frequency_penalty;
        this.logit_bias = fields?.logit_bias ?? this.logit_bias;
        this.max_tokens = fields?.max_tokens ?? this.max_tokens;
        this.n = fields?.n ?? this.n;
        this.presence_penalty = fields?.presence_penalty ?? this.presence_penalty;
        this.response_format = fields?.response_format ?? this.response_format;
        this.seed = fields?.seed ?? this.seed;
        this.stop = fields?.stop ?? this.stop;
        this.temperature = fields?.temperature ?? this.temperature;
        this.top_p = fields?.top_p ?? this.top_p;
        this.tools = fields?.tools ?? this.tools;
        this.user = fields?.user ?? this.user;
        this.streaming = fields?.streaming ?? this.streaming;
    }
    // Identifies this model type for tracing and callbacks.
    _llmType() {
        return "prem";
    }
    async completionWithRetry(request,
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    options) {
        return this.caller.call(async () => this.client.chat.completions.create(request, options));
    }
    invocationParams(options) {
        const params = super.invocationParams(options);
        return {
            ...params,
            project_id: this.project_id,
            session_id: this.session_id,
            messages: this.messages,
            model: this.model,
            system_prompt: this.system_prompt,
            frequency_penalty: this.frequency_penalty,
            logit_bias: this.logit_bias,
            max_tokens: this.max_tokens,
            n: this.n,
            presence_penalty: this.presence_penalty,
            response_format: this.response_format,
            seed: this.seed,
            stop: this.stop,
            temperature: this.temperature,
            top_p: this.top_p,
            tools: this.tools,
            user: this.user,
            streaming: this.streaming,
            stream: this.streaming,
        };
    }
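    // Illustrative shape of the merged request (values are placeholders, not
    // defaults; undefined fields are typically dropped at JSON serialization):
    //   { project_id: 123, model: "some-model-slug", temperature: 0.7,
    //     max_tokens: 256, stream: false, ... }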
    /**
     * Supports streaming; yields generation chunks as the API returns them.
     */
    async *_streamResponseChunks(messages, options, runManager) {
        const params = this.invocationParams(options);
        const messagesMapped = convertMessagesToPremParams(messages);
        // All models have a built-in `this.caller` property for retries
        const stream = await this.caller.call(async () => this.completionWithRetry({
            ...params,
            messages: messagesMapped,
            stream: true,
        }, params));
        for await (const data of stream) {
            const choice = data?.choices[0];
            if (!choice) {
                continue;
            }
            const chunk = new ChatGenerationChunk({
                message: _convertDeltaToMessageChunk(choice.delta ?? {}),
                text: choice.delta?.content ?? "",
                generationInfo: {
                    finishReason: choice.finish_reason,
                },
            });
            yield chunk;
            void runManager?.handleLLMNewToken(chunk.text ?? "");
        }
        if (options.signal?.aborted) {
            throw new Error("AbortError");
        }
    }
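    // Illustrative streaming call (not executed; callers should normally use
    // the public `.stream()` inherited from BaseChatModel rather than this
    // internal generator):
    //   for await (const chunk of model._streamResponseChunks(msgs, {})) {
    //     process.stdout.write(chunk.text);
    //   }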
    /** @ignore */
    _combineLLMOutput() {
        return [];
    }
    async _generate(messages, options, runManager) {
        const tokenUsage = {};
        const params = this.invocationParams(options);
        const messagesMapped = convertMessagesToPremParams(messages);
        if (params.streaming) {
            const stream = this._streamResponseChunks(messages, options, runManager);
            const finalChunks = {};
            for await (const chunk of stream) {
                const index = chunk.generationInfo?.completion ?? 0;
                if (finalChunks[index] === undefined) {
                    finalChunks[index] = chunk;
                }
                else {
                    finalChunks[index] = finalChunks[index].concat(chunk);
                }
            }
            const generations = Object.entries(finalChunks)
                .sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10))
                .map(([_, value]) => value);
            return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } };
        }
        else {
            const data = await this.completionWithRetry({
                ...params,
                stream: false,
                messages: messagesMapped,
            }, {
                signal: options?.signal,
            });
            if ("usage" in data && data.usage) {
                const { completion_tokens: completionTokens, prompt_tokens: promptTokens, total_tokens: totalTokens, } = data.usage;
                if (completionTokens) {
                    tokenUsage.completionTokens =
                        (tokenUsage.completionTokens ?? 0) + completionTokens;
                }
                if (promptTokens) {
                    tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens;
                }
                if (totalTokens) {
                    tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens;
                }
            }
            const generations = [];
            if ("choices" in data && data.choices) {
                for (const part of data.choices) {
                    const text = part.message?.content ?? "";
                    const generation = {
                        text,
                        message: premResponseToChatMessage(part.message ?? { role: "assistant" }),
                    };
                    generation.generationInfo = {
                        ...(part.finish_reason
                            ? { finish_reason: part.finish_reason }
                            : {}),
                        ...(part.logprobs ? { logprobs: part.logprobs } : {}),
                    };
                    generations.push(generation);
                }
            }
            return {
                generations,
                llmOutput: { tokenUsage },
            };
        }
    }
}
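// Minimal usage sketch (not part of the module; assumes PREM_API_KEY and
// PREM_PROJECT_ID are set in the environment, and that HumanMessage is
// imported from "@langchain/core/messages"):
//
//   const model = new ChatPrem({ temperature: 0 });
//   const result = await model.invoke([new HumanMessage("Hello!")]);
//   console.log(result.content);
//
//   // Streaming via the public Runnable API:
//   const stream = await model.stream([new HumanMessage("Hello!")]);
//   for await (const chunk of stream) process.stdout.write(chunk.content);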