import {
    InvokeEndpointCommand,
    InvokeEndpointWithResponseStreamCommand,
    SageMakerRuntimeClient,
} from "@aws-sdk/client-sagemaker-runtime";
import { GenerationChunk } from "@langchain/core/outputs";
import { LLM } from "@langchain/core/language_models/llms";

/**
 * A handler class to transform input from the LLM to the format that a
 * SageMaker endpoint expects. Similarly, the class also handles transforming
 * output from the SageMaker endpoint to the format that the LLM class expects.
 *
 * This base class only supplies the default `contentType` and `accepts`
 * values ("text/plain"). `SageMakerEndpoint` additionally calls
 * `transformInput(prompt, modelKwargs)` and `transformOutput(output)` on the
 * handler, so concrete handlers must provide both methods.
 *
 * Example:
 * ```
 * class ContentHandler extends BaseSageMakerContentHandler {
 *   contentType = "application/json"
 *   accepts = "application/json"
 *
 *   transformInput(prompt: string, modelKwargs: Record<string, unknown>) {
 *     const inputString = JSON.stringify({
 *       prompt,
 *       ...modelKwargs
 *     })
 *     return Buffer.from(inputString)
 *   }
 *
 *   transformOutput(output: Uint8Array) {
 *     const responseJson = JSON.parse(Buffer.from(output).toString("utf-8"))
 *     return responseJson[0].generated_text
 *   }
 *
 * }
 * ```
 */
export class BaseSageMakerContentHandler {
    constructor() {
        // MIME type sent as the request `ContentType`.
        Object.defineProperty(this, "contentType", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "text/plain"
        });
        // MIME type sent as the request `Accept`.
        Object.defineProperty(this, "accepts", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "text/plain"
        });
    }
}

/**
 * The SageMakerEndpoint class is used to interact with SageMaker
 * Inference Endpoint models. It uses the AWS client for authentication,
 * which automatically loads credentials.
 * If a specific credential profile is to be used, the name of the profile
 * from the ~/.aws/credentials file must be passed. The credentials or
 * roles used should have the required policies to access the SageMaker
 * endpoint.
 */
export class SageMakerEndpoint extends LLM {
    static lc_name() {
        return "SageMakerEndpoint";
    }

    /**
     * Maps credential fields inside `clientOptions` to the names of the
     * environment variables that hold them.
     */
    get lc_secrets() {
        return {
            "clientOptions.credentials.accessKeyId": "AWS_ACCESS_KEY_ID",
            "clientOptions.credentials.secretAccessKey": "AWS_SECRET_ACCESS_KEY",
            "clientOptions.credentials.sessionToken": "AWS_SESSION_TOKEN",
        };
    }

    /**
     * @param fields Construction options, also forwarded to the `LLM` base class.
     * @param fields.clientOptions Options for `SageMakerRuntimeClient`; must include `region`.
     * @param fields.endpointName Name of the SageMaker endpoint to invoke (required).
     * @param fields.contentHandler Handler converting prompts/responses (required).
     * @param fields.modelKwargs Optional extra arguments handed to `contentHandler.transformInput`.
     * @param fields.endpointKwargs Optional extra fields spread into the invoke command input.
     * @param fields.streaming Whether `_call` should use the streaming endpoint; defaults to `false`.
     * @throws {Error} If `clientOptions.region`, `endpointName` or `contentHandler` is missing.
     */
    constructor(fields) {
        super(fields);
        Object.defineProperty(this, "lc_serializable", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: true
        });
        Object.defineProperty(this, "endpointName", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "modelKwargs", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "endpointKwargs", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "client", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "contentHandler", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "streaming", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        // Optional chaining so that a missing `clientOptions` object produces
        // the descriptive error below rather than a raw TypeError.
        if (!fields?.clientOptions?.region) {
            throw new Error(`Please pass a "clientOptions" object with a "region" field to the constructor`);
        }
        const endpointName = fields?.endpointName;
        if (!endpointName) {
            throw new Error(`Please pass an "endpointName" field to the constructor`);
        }
        const contentHandler = fields?.contentHandler;
        if (!contentHandler) {
            throw new Error(`Please pass a "contentHandler" field to the constructor`);
        }
        this.endpointName = endpointName;
        this.contentHandler = contentHandler;
        this.endpointKwargs = fields.endpointKwargs;
        this.modelKwargs = fields.modelKwargs;
        this.streaming = fields.streaming ?? false;
        this.client = new SageMakerRuntimeClient(fields.clientOptions);
    }

    _llmType() {
        return "sagemaker_endpoint";
    }

    /**
     * Calls the SageMaker endpoint and retrieves the result. Dispatches to
     * the streaming or non-streaming implementation based on `this.streaming`;
     * `runManager` is only used on the streaming path.
     * @param {string} prompt The input prompt.
     * @param {this["ParsedCallOptions"]} options Parsed call options.
     * @param {CallbackManagerForLLMRun} runManager Optional run manager.
     * @returns {Promise<string>} A promise that resolves to the generated string.
     * @ignore
     */
    async _call(prompt, options, runManager) {
        return this.streaming
            ? await this.streamingCall(prompt, options, runManager)
            : await this.noStreamingCall(prompt, options);
    }

    /**
     * Consumes `_streamResponseChunks` to completion and concatenates the
     * text of every chunk into a single string.
     * @param {string} prompt The input prompt.
     * @param {this["ParsedCallOptions"]} options Parsed call options.
     * @param {CallbackManagerForLLMRun} runManager Optional run manager.
     * @returns {Promise<string>} The joined text of all streamed chunks.
     */
    async streamingCall(prompt, options, runManager) {
        const chunks = [];
        for await (const chunk of this._streamResponseChunks(prompt, options, runManager)) {
            chunks.push(chunk.text);
        }
        return chunks.join("");
    }

    /**
     * Sends a single `InvokeEndpointCommand` and converts the response body
     * with the content handler.
     * @param {string} prompt The input prompt.
     * @param {this["ParsedCallOptions"]} options Parsed call options; `options.signal` is passed as the abort signal.
     * @returns {Promise<string>} The output of `contentHandler.transformOutput`.
     * @throws {Error} If the response has no `Body`.
     */
    async noStreamingCall(prompt, options) {
        const body = await this.contentHandler.transformInput(prompt, this.modelKwargs ?? {});
        const { contentType, accepts } = this.contentHandler;
        // Route the request through the caller inherited from the LangChain
        // base class rather than invoking the client directly.
        const response = await this.caller.call(() => this.client.send(new InvokeEndpointCommand({
            EndpointName: this.endpointName,
            Body: body,
            ContentType: contentType,
            Accept: accepts,
            // Spread last so endpointKwargs can override the fields above.
            ...this.endpointKwargs,
        }), { abortSignal: options.signal }));
        if (response.Body === undefined) {
            throw new Error("Inference result missing Body");
        }
        return this.contentHandler.transformOutput(response.Body);
    }

    /**
     * Streams response chunks from the SageMaker endpoint.
     * @param {string} prompt The input prompt.
     * @param {this["ParsedCallOptions"]} options Parsed call options; `options.signal` is passed as the abort signal.
     * @param {CallbackManagerForLLMRun} runManager Optional run manager, notified of each new token.
     * @returns {AsyncGenerator<GenerationChunk>} An asynchronous generator yielding generation chunks.
     * @throws {Error} If the response has no `Body`, or if the stream reports
     * an `InternalStreamFailure` or `ModelStreamError` event.
     */
    async *_streamResponseChunks(prompt, options, runManager) {
        const body = await this.contentHandler.transformInput(prompt, this.modelKwargs ?? {});
        const { contentType, accepts } = this.contentHandler;
        // Route the request through the caller inherited from the LangChain
        // base class rather than invoking the client directly.
        const stream = await this.caller.call(() => this.client.send(new InvokeEndpointWithResponseStreamCommand({
            EndpointName: this.endpointName,
            Body: body,
            ContentType: contentType,
            Accept: accepts,
            // Spread last so endpointKwargs can override the fields above.
            ...this.endpointKwargs,
        }), { abortSignal: options.signal }));
        if (!stream.Body) {
            throw new Error("Inference result missing Body");
        }
        for await (const chunk of stream.Body) {
            if (chunk.PayloadPart && chunk.PayloadPart.Bytes) {
                const text = await this.contentHandler.transformOutput(chunk.PayloadPart.Bytes);
                yield new GenerationChunk({
                    text,
                    // Copy the raw event fields, with `response` explicitly
                    // set to undefined.
                    generationInfo: {
                        ...chunk,
                        response: undefined,
                    },
                });
                await runManager?.handleLLMNewToken(text);
            }
            else if (chunk.InternalStreamFailure) {
                throw new Error(chunk.InternalStreamFailure.message);
            }
            else if (chunk.ModelStreamError) {
                throw new Error(chunk.ModelStreamError.message);
            }
        }
    }
}