"use strict";
var __importDefault = (this && this.__importDefault) || function (mod) {
    return (mod && mod.__esModule) ? mod : { "default": mod };
};
Object.defineProperty(exports, "__esModule", { value: true });
exports.ChatPrem = exports.messageToPremRole = void 0;
const messages_1 = require("@langchain/core/messages");
const chat_models_1 = require("@langchain/core/language_models/chat_models");
const prem_sdk_1 = __importDefault(require("@premai/prem-sdk"));
const env_1 = require("@langchain/core/utils/env");
const outputs_1 = require("@langchain/core/outputs");
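/**
 * Returns the custom role of a generic ChatMessage, warning when the
 * role is not one Prem recognizes ("assistant" or "user").
 */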
function extractGenericMessageCustomRole(message) {
    if (message.role !== "assistant" && message.role !== "user") {
        console.warn(`Unknown message role: ${message.role}`);
    }
    return message.role;
}
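/**
 * Maps a LangChain message type ("ai", "human", or "generic") to the
 * corresponding Prem chat role.
 */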
function messageToPremRole(message) {
    const type = message._getType();
    switch (type) {
        case "ai":
            return "assistant";
        case "human":
            return "user";
        case "generic": {
            if (!messages_1.ChatMessage.isInstance(message))
                throw new Error("Invalid generic chat message");
            return extractGenericMessageCustomRole(message);
        }
        default:
            throw new Error(`Unknown message type: ${type}`);
    }
}
exports.messageToPremRole = messageToPremRole;
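/**
 * Converts LangChain messages into the plain role/content objects the
 * Prem API expects; only string content is supported.
 */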
function convertMessagesToPremParams(messages) {
    return messages.map((message) => {
        if (typeof message.content !== "string") {
            throw new Error("Non string message content not supported");
        }
        return {
            role: messageToPremRole(message),
            content: message.content,
            name: message.name,
            function_call: message.additional_kwargs.function_call,
        };
    });
}
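/**
 * Converts a Prem response message back into a LangChain message,
 * falling back to a generic ChatMessage for non-assistant roles.
 */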
function premResponseToChatMessage(message) {
    switch (message.role) {
        case "assistant":
            return new messages_1.AIMessage(message.content || "");
        default:
            return new messages_1.ChatMessage(message.content || "", message.role ?? "unknown");
    }
}
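/**
 * Converts a streaming delta from the Prem API into the matching
 * LangChain message chunk, carrying over any function_call arguments.
 */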
// eslint-disable-next-line @typescript-eslint/no-explicit-any
function _convertDeltaToMessageChunk(delta) {
    const { role } = delta;
    const content = delta.content ?? "";
    let additional_kwargs;
    if (delta.function_call) {
        additional_kwargs = {
            function_call: delta.function_call,
        };
    }
    else {
        additional_kwargs = {};
    }
    if (role === "user") {
        return new messages_1.HumanMessageChunk({ content });
    }
    else if (role === "assistant") {
        return new messages_1.AIMessageChunk({ content, additional_kwargs });
    }
    else {
        return new messages_1.ChatMessageChunk({ content, role });
    }
}
/**
 * LangChain chat model integration for the Prem AI platform.
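 *
 * @example
 * // A minimal usage sketch; assumes PREM_API_KEY and PREM_PROJECT_ID are
 * // set in the environment, and the model name is illustrative.
 * const model = new ChatPrem({ model: "gpt-3.5-turbo" });
 * const response = await model.invoke("Tell me a joke.");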
 */
class ChatPrem extends chat_models_1.BaseChatModel {
    // Used for tracing; returns the class name.
    static lc_name() {
        return "ChatPrem";
    }
    /**
     * Declares which constructor fields are secrets and the
     * environment variables they map to when serialized.
     */
    get lc_secrets() {
        return {
            apiKey: "PREM_API_KEY",
        };
    }
    get lc_aliases() {
        return {
            apiKey: "PREM_API_KEY",
        };
    }
    constructor(fields) {
        super(fields ?? {});
        Object.defineProperty(this, "client", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "apiKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "project_id", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "session_id", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "messages", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "model", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "system_prompt", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "frequency_penalty", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "logit_bias", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "max_tokens", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "n", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "presence_penalty", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "response_format", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "seed", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "stop", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "temperature", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "top_p", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "tools", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "user", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "streaming", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: false
        });
        Object.defineProperty(this, "lc_serializable", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: true
        });
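        // Resolve credentials: explicitly passed fields take precedence
        // over environment variables.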
        const apiKey = fields?.apiKey ?? (0, env_1.getEnvironmentVariable)("PREM_API_KEY");
        if (!apiKey) {
            throw new Error(`Prem API key not found. Please set the PREM_API_KEY environment variable or pass the key via the "apiKey" field.`);
        }
        const projectId = fields?.project_id ??
            parseInt((0, env_1.getEnvironmentVariable)("PREM_PROJECT_ID") ?? "-1", 10);
        if (!projectId || projectId === -1 || typeof projectId !== "number") {
            throw new Error(`Prem project ID not found. Please set the PREM_PROJECT_ID environment variable or pass the ID via the "project_id" field.`);
        }
        this.client = new prem_sdk_1.default({
            apiKey,
        });
        this.project_id = projectId;
        this.session_id = fields?.session_id ?? this.session_id;
        this.messages = fields?.messages ?? this.messages;
        this.model = fields?.model ?? this.model;
        this.system_prompt = fields?.system_prompt ?? this.system_prompt;
        this.frequency_penalty =
            fields?.frequency_penalty ?? this.frequency_penalty;
        this.logit_bias = fields?.logit_bias ?? this.logit_bias;
        this.max_tokens = fields?.max_tokens ?? this.max_tokens;
        this.n = fields?.n ?? this.n;
        this.presence_penalty = fields?.presence_penalty ?? this.presence_penalty;
        this.response_format = fields?.response_format ?? this.response_format;
        this.seed = fields?.seed ?? this.seed;
        this.stop = fields?.stop ?? this.stop;
        this.temperature = fields?.temperature ?? this.temperature;
        this.top_p = fields?.top_p ?? this.top_p;
        this.tools = fields?.tools ?? this.tools;
        this.user = fields?.user ?? this.user;
        this.streaming = fields?.streaming ?? this.streaming;
    }
    // The LLM type identifier used in callbacks and serialization.
    _llmType() {
        return "prem";
    }
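    /**
     * Calls the Prem chat completions endpoint, routing the request
     * through `this.caller` so transient failures are retried.
     */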
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    async completionWithRetry(request, options) {
        return this.caller.call(async () => this.client.chat.completions.create(request, options));
    }
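    /**
     * Builds the full parameter object sent with every Prem request,
     * merging base invocation params with this instance's settings.
     */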
    invocationParams(options) {
        const params = super.invocationParams(options);
        return {
            ...params,
            project_id: this.project_id,
            session_id: this.session_id,
            messages: this.messages,
            model: this.model,
            system_prompt: this.system_prompt,
            frequency_penalty: this.frequency_penalty,
            logit_bias: this.logit_bias,
            max_tokens: this.max_tokens,
            n: this.n,
            presence_penalty: this.presence_penalty,
            response_format: this.response_format,
            seed: this.seed,
            stop: this.stop,
            temperature: this.temperature,
            top_p: this.top_p,
            tools: this.tools,
            user: this.user,
            streaming: this.streaming,
            stream: this.streaming,
        };
    }
    /**
     * Streams chat completion chunks from the Prem API, yielding one
     * ChatGenerationChunk per delta received.
     */
    async *_streamResponseChunks(messages, options, runManager) {
        const params = this.invocationParams(options);
        const messagesMapped = convertMessagesToPremParams(messages);
        // `completionWithRetry` already routes the request through the
        // built-in `this.caller`, which handles retries, so it is not
        // wrapped in a second `this.caller.call` here.
        const stream = await this.completionWithRetry({
            ...params,
            messages: messagesMapped,
            stream: true,
        }, params);
        for await (const data of stream) {
            const choice = data?.choices[0];
            if (!choice) {
                continue;
            }
            const chunk = new outputs_1.ChatGenerationChunk({
                message: _convertDeltaToMessageChunk(choice.delta ?? {}),
                text: choice.delta?.content ?? "",
                generationInfo: {
                    finishReason: choice.finish_reason,
                },
            });
            yield chunk;
            void runManager?.handleLLMNewToken(chunk.text ?? "");
        }
        if (options.signal?.aborted) {
            throw new Error("AbortError");
        }
    }
    /** @ignore */
    _combineLLMOutput() {
        return [];
    }
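    /**
     * Generates chat completions, either by aggregating the streaming
     * iterator when `streaming` is enabled or via a single request.
     */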
    async _generate(messages, options, runManager) {
        const tokenUsage = {};
        const params = this.invocationParams(options);
        const messagesMapped = convertMessagesToPremParams(messages);
        if (params.streaming) {
            const stream = this._streamResponseChunks(messages, options, runManager);
            const finalChunks = {};
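            // Merge streamed chunks into complete generations, keyed by
            // completion index (0 when the API does not report one).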
            for await (const chunk of stream) {
                const index = chunk.generationInfo?.completion ?? 0;
                if (finalChunks[index] === undefined) {
                    finalChunks[index] = chunk;
                }
                else {
                    finalChunks[index] = finalChunks[index].concat(chunk);
                }
            }
            const generations = Object.entries(finalChunks)
                .sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10))
                .map(([_, value]) => value);
            return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } };
        }
        else {
            const data = await this.completionWithRetry({
                ...params,
                stream: false,
                messages: messagesMapped,
            }, {
                signal: options?.signal,
            });
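            // Accumulate token usage when the API reports it.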
            if ("usage" in data && data.usage) {
                const { completion_tokens: completionTokens, prompt_tokens: promptTokens, total_tokens: totalTokens, } = data.usage;
                if (completionTokens) {
                    tokenUsage.completionTokens =
                        (tokenUsage.completionTokens ?? 0) + completionTokens;
                }
                if (promptTokens) {
                    tokenUsage.promptTokens =
                        (tokenUsage.promptTokens ?? 0) + promptTokens;
                }
                if (totalTokens) {
                    tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens;
                }
            }
            const generations = [];
            if ("choices" in data && data.choices) {
                for (const part of data.choices) {
                    const text = part.message?.content ?? "";
                    const generation = {
                        text,
                        message: premResponseToChatMessage(part.message ?? { role: "assistant" }),
                    };
                    generation.generationInfo = {
                        ...(part.finish_reason
                            ? { finish_reason: part.finish_reason }
                            : {}),
                        ...(part.logprobs ? { logprobs: part.logprobs } : {}),
                    };
                    generations.push(generation);
                }
            }
            return {
                generations,
                llmOutput: { tokenUsage },
            };
        }
    }
}
exports.ChatPrem = ChatPrem;