"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.Bedrock = void 0;
const signature_v4_1 = require("@smithy/signature-v4");
const protocol_http_1 = require("@smithy/protocol-http");
const eventstream_codec_1 = require("@smithy/eventstream-codec");
const util_utf8_1 = require("@smithy/util-utf8");
const sha256_js_1 = require("@aws-crypto/sha256-js");
const env_1 = require("@langchain/core/utils/env");
const outputs_1 = require("@langchain/core/outputs");
const llms_1 = require("@langchain/core/language_models/llms");
const index_js_1 = require("../../utils/bedrock/index.cjs");
/**
 * Prefixes that `isInferenceModel` looks for in the first dot-separated
 * segment of a model id. When one matches, `getModelProvider` reads the
 * provider from the second segment instead of the first.
 */
const AWS_REGIONS = ["us", "sa", "me", "il", "eu", "cn", "ca", "ap", "af", "us-gov"];
/**
 * Model providers accepted by the `Bedrock` constructor; any other provider
 * makes the constructor throw. The order is visible in that error message.
 */
const ALLOWED_MODEL_PROVIDERS = ["ai21", "anthropic", "amazon", "cohere", "meta", "mistral"];
/**
 * Number of leading bytes of a streamed message that are read as its
 * big-endian unsigned 32-bit length when splitting the response stream.
 */
const PRELUDE_TOTAL_LENGTH_BYTES = 4;
/**
 * A type of Large Language Model (LLM) that interacts with the Bedrock
 * service. It extends the base `LLM` class and implements the
 * `BaseBedrockInput` interface. The class is designed to authenticate and
 * interact with the Bedrock service, which is a part of Amazon Web
 * Services (AWS). It uses AWS credentials for authentication and can be
 * configured with various parameters such as the model to use, the AWS
 * region, and the maximum number of tokens to generate.
 *
 * Requests are signed with AWS Signature V4 and sent through `fetchFn`
 * (the global `fetch` unless one is supplied). When `streaming` is true,
 * `_call` aggregates the chunks produced by `_streamResponseChunks`.
 */
class Bedrock extends llms_1.LLM {
    /** Field aliases: `model` -> `model_id`, `region` -> `region_name`. */
    get lc_aliases() {
        return {
            model: "model_id",
            region: "region_name",
        };
    }
    /** Maps the credential field paths to the secret identifiers used for them. */
    get lc_secrets() {
        return {
            "credentials.accessKeyId": "BEDROCK_AWS_ACCESS_KEY_ID",
            "credentials.secretAccessKey": "BEDROCK_AWS_SECRET_ACCESS_KEY",
        };
    }
    /** Extra attributes exposed alongside the constructor fields (the resolved region). */
    get lc_attributes() {
        return { region: this.region };
    }
    /** Identifier of this LLM type. */
    _llmType() {
        return "bedrock";
    }
    /** Name of this class. */
    static lc_name() {
        return "Bedrock";
    }
    /**
     * @param fields Optional configuration. Reads `model`, `region`,
     *   `credentials`, `temperature`, `maxTokens`, `fetchFn`, `endpointHost`
     *   (or `endpointUrl`), `stopSequences`, `modelKwargs` and `streaming`.
     * @throws {Error} If the model's provider is not in
     *   `ALLOWED_MODEL_PROVIDERS`, if no region is given and
     *   `AWS_DEFAULT_REGION` is unset, or if `credentials` is missing.
     */
    constructor(fields) {
        super(fields ?? {});
        // Declare every instance field with its default before applying `fields`.
        Object.defineProperty(this, "model", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "amazon.titan-tg1-large"
        });
        Object.defineProperty(this, "modelProvider", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "region", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "credentials", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "temperature", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: undefined
        });
        Object.defineProperty(this, "maxTokens", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: undefined
        });
        Object.defineProperty(this, "fetchFn", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "endpointHost", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        /** @deprecated */
        Object.defineProperty(this, "stopSequences", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "modelKwargs", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "codec", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: new eventstream_codec_1.EventStreamCodec(util_utf8_1.toUtf8, util_utf8_1.fromUtf8)
        });
        Object.defineProperty(this, "streaming", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: false
        });
        Object.defineProperty(this, "lc_serializable", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: true
        });
        this.model = fields?.model ?? this.model;
        // The provider is derived from the model id and validated up front.
        this.modelProvider = getModelProvider(this.model);
        if (!ALLOWED_MODEL_PROVIDERS.includes(this.modelProvider)) {
            throw new Error(`Unknown model provider: '${this.modelProvider}', only these are supported: ${ALLOWED_MODEL_PROVIDERS}`);
        }
        const region = fields?.region ?? (0, env_1.getEnvironmentVariable)("AWS_DEFAULT_REGION");
        if (!region) {
            throw new Error("Please set the AWS_DEFAULT_REGION environment variable or pass it to the constructor as the region field.");
        }
        this.region = region;
        const credentials = fields?.credentials;
        if (!credentials) {
            throw new Error("Please set the AWS credentials in the 'credentials' field.");
        }
        this.credentials = credentials;
        this.temperature = fields?.temperature ?? this.temperature;
        this.maxTokens = fields?.maxTokens ?? this.maxTokens;
        this.fetchFn = fields?.fetchFn ?? fetch.bind(globalThis);
        // `endpointUrl` is accepted as a fallback name for `endpointHost`.
        this.endpointHost = fields?.endpointHost ?? fields?.endpointUrl;
        this.stopSequences = fields?.stopSequences;
        this.modelKwargs = fields?.modelKwargs;
        this.streaming = fields?.streaming ?? this.streaming;
    }
    /** Call out to Bedrock service model.
      Arguments:
        prompt: The prompt to pass into the model.
        options: Call options; `stop` and `signal` are read from it.
        runManager: Optional callback manager notified of streamed tokens.

      Returns:
        The string generated by the model.

      Throws:
        Error if the service answers with a non-OK status.

      Example:
        response = model.invoke("Tell me a joke.")
    */
    async _call(prompt, options, runManager) {
        if (this.streaming) {
            // Fold every streamed chunk into one and return its text.
            const stream = this._streamResponseChunks(prompt, options, runManager);
            let finalResult;
            for await (const chunk of stream) {
                finalResult = finalResult === undefined ? chunk : finalResult.concat(chunk);
            }
            return finalResult?.text ?? "";
        }
        const service = "bedrock-runtime";
        const endpointHost = this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
        const provider = this.modelProvider;
        const response = await this._signedFetch(prompt, options, {
            bedrockMethod: "invoke",
            endpointHost,
            provider,
        });
        if (!response.ok) {
            // Read the body as text first: an error body is not guaranteed to be
            // JSON, and a parse failure must not hide the HTTP status.
            const rawBody = await response.text();
            let detail = rawBody;
            try {
                const parsed = JSON.parse(rawBody);
                detail = parsed.message ?? JSON.stringify(parsed);
            }
            catch {
                // Not JSON (or no usable message): report the raw body instead.
            }
            throw new Error(`Error ${response.status}: ${detail}`);
        }
        const json = await response.json();
        return index_js_1.BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
    }
    /**
     * Builds the provider-specific request body, signs the request with AWS
     * Signature V4 and sends it with `fetchFn`.
     *
     * @param prompt Prompt text passed to the input adapter.
     * @param options Call options; `stop` overrides `stopSequences`, and
     *   `signal` is forwarded to the caller.
     * @param fields `{ bedrockMethod, endpointHost, provider }` describing the
     *   endpoint path segment, the host and the model provider.
     * @returns The response returned by `fetchFn`.
     */
    async _signedFetch(prompt, options, fields) {
        const { bedrockMethod, endpointHost, provider } = fields;
        const inputBody = index_js_1.BedrockLLMInputOutputAdapter.prepareInput(provider, prompt, this.maxTokens, this.temperature, options.stop ?? this.stopSequences, this.modelKwargs, bedrockMethod);
        const url = new URL(`https://${endpointHost}/model/${this.model}/${bedrockMethod}`);
        const request = new protocol_http_1.HttpRequest({
            hostname: url.hostname,
            path: url.pathname,
            protocol: url.protocol,
            method: "POST",
            body: JSON.stringify(inputBody),
            query: Object.fromEntries(url.searchParams.entries()),
            headers: {
                // host is required by AWS Signature V4: https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
                host: url.host,
                accept: "application/json",
                "content-type": "application/json",
            },
        });
        const signer = new signature_v4_1.SignatureV4({
            credentials: this.credentials,
            service: "bedrock",
            region: this.region,
            sha256: sha256_js_1.Sha256,
        });
        const signedRequest = await signer.sign(request);
        // Send request to AWS using the low-level fetch API
        const response = await this.caller.callWithOptions({ signal: options.signal }, async () => this.fetchFn(url, {
            headers: signedRequest.headers,
            body: signedRequest.body,
            method: signedRequest.method,
        }));
        return response;
    }
    /**
     * Returns the parameters this instance uses for an invocation.
     *
     * @param options Optional call options; `stop` overrides `stopSequences`.
     */
    invocationParams(options) {
        return {
            model: this.model,
            region: this.region,
            temperature: this.temperature,
            maxTokens: this.maxTokens,
            stop: options?.stop ?? this.stopSequences,
            modelKwargs: this.modelKwargs,
        };
    }
    /**
     * Async generator yielding `GenerationChunk`s for a prompt.
     *
     * For the anthropic, cohere, meta and mistral providers the
     * `invoke-with-response-stream` method is used and one chunk is yielded per
     * decoded event carrying `bytes`. Other providers use a plain `invoke` and
     * yield a single chunk. `runManager.handleLLMNewToken` is called (without
     * being awaited) after each yielded chunk.
     *
     * @throws {Error} On a non-2xx status, a missing response body, an event
     *   with an unexpected type or content type, or an event whose body
     *   carries a `message`.
     */
    async *_streamResponseChunks(prompt, options, runManager) {
        const provider = this.modelProvider;
        const usesResponseStream = provider === "anthropic" ||
            provider === "cohere" ||
            provider === "meta" ||
            provider === "mistral";
        const bedrockMethod = usesResponseStream
            ? "invoke-with-response-stream"
            : "invoke";
        const service = "bedrock-runtime";
        const endpointHost = this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
        // Send request to AWS using the low-level fetch API
        const response = await this._signedFetch(prompt, options, {
            bedrockMethod,
            endpointHost,
            provider,
        });
        if (response.status < 200 || response.status >= 300) {
            throw Error(`Failed to access underlying url '${endpointHost}': got ${response.status} ${response.statusText}: ${await response.text()}`);
        }
        if (usesResponseStream) {
            const reader = response.body?.getReader();
            if (!reader) {
                // Fail with a clear message instead of a TypeError on `reader.read()`.
                throw new Error("Bedrock streaming response did not include a readable body.");
            }
            const decoder = new TextDecoder();
            for await (const chunk of this._readChunks(reader)) {
                const event = this.codec.decode(chunk);
                const eventTypeHeader = event.headers[":event-type"];
                // `?.` so a missing content-type header produces the error below
                // rather than a TypeError.
                const contentType = event.headers[":content-type"]?.value;
                if ((eventTypeHeader !== undefined && eventTypeHeader.value !== "chunk") ||
                    contentType !== "application/json") {
                    throw Error(`Failed to get event chunk: got ${chunk}`);
                }
                const body = JSON.parse(decoder.decode(event.body));
                if (body.message) {
                    throw new Error(body.message);
                }
                if (body.bytes !== undefined) {
                    // `bytes` is base64; decode it to bytes, then to text, then parse.
                    const chunkResult = JSON.parse(decoder.decode(Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 0)));
                    const text = index_js_1.BedrockLLMInputOutputAdapter.prepareOutput(provider, chunkResult);
                    yield new outputs_1.GenerationChunk({
                        text,
                        generationInfo: {},
                    });
                    // eslint-disable-next-line no-void
                    void runManager?.handleLLMNewToken(text);
                }
            }
        }
        else {
            const json = await response.json();
            const text = index_js_1.BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
            yield new outputs_1.GenerationChunk({
                text,
                generationInfo: {},
            });
            // eslint-disable-next-line no-void
            void runManager?.handleLLMNewToken(text);
        }
    }
    /**
     * Wraps a stream reader in an async iterable that re-frames the byte
     * stream into whole messages. Each message starts with a big-endian
     * unsigned 32-bit total length, and reads are buffered until that many
     * bytes are available. Bytes left over when the stream ends (an incomplete
     * message) are not yielded.
     *
     * @param reader Object with `read()` resolving to `{ done, value }` where
     *   `value` is a `Uint8Array`; `cancel()` is called if iteration stops
     *   before the stream is done.
     * @returns An async iterable of `Uint8Array`, one per complete message.
     */
    _readChunks(reader) {
        function _concatChunks(a, b) {
            const newBuffer = new Uint8Array(a.length + b.length);
            newBuffer.set(a);
            newBuffer.set(b, a.length);
            return newBuffer;
        }
        function getMessageLength(buffer) {
            if (buffer.byteLength < PRELUDE_TOTAL_LENGTH_BYTES)
                return 0;
            const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
            return view.getUint32(0, false);
        }
        return {
            async *[Symbol.asyncIterator]() {
                let streamEnded = false;
                try {
                    let readResult = await reader.read();
                    let buffer = new Uint8Array(0);
                    while (!readResult.done) {
                        buffer = _concatChunks(buffer, readResult.value);
                        let messageLength = getMessageLength(buffer);
                        while (buffer.byteLength >= PRELUDE_TOTAL_LENGTH_BYTES &&
                            buffer.byteLength >= messageLength) {
                            if (messageLength < PRELUDE_TOTAL_LENGTH_BYTES) {
                                // A message cannot be shorter than its own length
                                // prefix; without this check a declared length of 0
                                // would never advance the buffer.
                                throw new Error(`Invalid event stream message length: ${messageLength}`);
                            }
                            yield buffer.slice(0, messageLength);
                            buffer = buffer.slice(messageLength);
                            messageLength = getMessageLength(buffer);
                        }
                        readResult = await reader.read();
                    }
                    streamEnded = true;
                }
                finally {
                    if (!streamEnded) {
                        // Iteration stopped early (consumer threw or returned, or a
                        // read failed): cancel so the response body is not left
                        // open. Best-effort only, so the original error propagates.
                        try {
                            await reader.cancel();
                        }
                        catch {
                            // Ignored deliberately; see above.
                        }
                    }
                }
            },
        };
    }
}
exports.Bedrock = Bedrock;
/**
 * Reports whether the first dot-separated segment of a model id is one of
 * the prefixes listed in `AWS_REGIONS`.
 *
 * @param {string} modelId Dot-separated model identifier.
 * @returns {boolean} True when the leading segment is in `AWS_REGIONS`.
 */
function isInferenceModel(modelId) {
    const [leadingSegment] = modelId.split(".");
    return AWS_REGIONS.includes(leadingSegment);
}
/**
 * Extracts the provider segment from a dot-separated model id: the second
 * segment when `isInferenceModel` reports a region prefix, otherwise the
 * first.
 *
 * @param {string} modelId Dot-separated model identifier.
 * @returns {string|undefined} The provider segment; `undefined` when the id
 *   has a region prefix but no second segment.
 */
function getModelProvider(modelId) {
    const [firstSegment, secondSegment] = modelId.split(".");
    return isInferenceModel(modelId) ? secondSegment : firstSegment;
}