agsamantha/node_modules/@langchain/community/dist/llms/bedrock/web.js
2024-10-02 15:15:21 -05:00

353 lines
13 KiB
JavaScript

import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
import { EventStreamCodec } from "@smithy/eventstream-codec";
import { fromUtf8, toUtf8 } from "@smithy/util-utf8";
import { Sha256 } from "@aws-crypto/sha256-js";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { GenerationChunk } from "@langchain/core/outputs";
import { LLM } from "@langchain/core/language_models/llms";
import { BedrockLLMInputOutputAdapter, } from "../../utils/bedrock/index.js";
/**
 * Geography prefixes that may precede the provider segment of a model id
 * (for example `us.anthropic.…`). `isInferenceModel` compares the first
 * dot-separated segment of a model id against this list.
 *
 * `apac` and `global` are included because cross-region inference profile
 * ids use those prefixes; without them such ids are parsed as if the prefix
 * were the model provider.
 */
const AWS_REGIONS = [
    "us",
    "sa",
    "me",
    "il",
    "eu",
    "cn",
    "ca",
    "ap",
    "af",
    "us-gov",
    "apac",
    "global",
];
// Provider names accepted by the Bedrock constructor; any other provider
// parsed from the model id is rejected with an "Unknown model provider" error.
const ALLOWED_MODEL_PROVIDERS = ["ai21", "anthropic", "amazon", "cohere", "meta", "mistral"];
// Number of leading bytes that must be buffered before the message length
// (read as an unsigned 32-bit integer at offset 0) can be decoded.
const PRELUDE_TOTAL_LENGTH_BYTES = 4;
/**
 * A type of Large Language Model (LLM) that interacts with the Bedrock
 * service. It extends the base `LLM` class and implements the
 * `BaseBedrockInput` interface. The class is designed to authenticate and
 * interact with the Bedrock service, which is a part of Amazon Web
 * Services (AWS). It uses AWS credentials for authentication and can be
 * configured with various parameters such as the model to use, the AWS
 * region, and the maximum number of tokens to generate.
 *
 * Requests are signed with AWS Signature V4 and sent through `fetchFn`
 * (the global `fetch` unless a replacement is supplied).
 */
export class Bedrock extends LLM {
    /**
     * Alias keys for the `model` and `region` fields.
     * NOTE(review): presumably consumed by the base class's serialization — confirm.
     */
    get lc_aliases() {
        return {
            model: "model_id",
            region: "region_name",
        };
    }
    /**
     * Maps credential field paths to the identifiers used for them as secrets.
     * NOTE(review): presumably consumed by the base class's serialization — confirm.
     */
    get lc_secrets() {
        return {
            "credentials.accessKeyId": "BEDROCK_AWS_ACCESS_KEY_ID",
            "credentials.secretAccessKey": "BEDROCK_AWS_SECRET_ACCESS_KEY",
        };
    }
    /** Extra attributes exposed alongside the instance: the configured region. */
    get lc_attributes() {
        return { region: this.region };
    }
    /** @returns {string} The fixed LLM type identifier, `"bedrock"`. */
    _llmType() {
        return "bedrock";
    }
    /** @returns {string} The fixed class name, `"Bedrock"`. */
    static lc_name() {
        return "Bedrock";
    }
    /**
     * @param fields Configuration object. Reads `model`, `region`,
     *   `credentials`, `temperature`, `maxTokens`, `fetchFn`, `endpointHost`
     *   (or `endpointUrl` as a fallback), `stopSequences`, `modelKwargs` and
     *   `streaming`.
     * @throws {Error} If the provider parsed from the model id is not in
     *   `ALLOWED_MODEL_PROVIDERS`, if no region is given and the
     *   `AWS_DEFAULT_REGION` environment variable is unset, or if
     *   `credentials` is missing.
     */
    constructor(fields) {
        super(fields ?? {});
        Object.defineProperty(this, "model", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "amazon.titan-tg1-large"
        });
        Object.defineProperty(this, "modelProvider", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "region", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "credentials", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "temperature", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: undefined
        });
        Object.defineProperty(this, "maxTokens", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: undefined
        });
        Object.defineProperty(this, "fetchFn", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "endpointHost", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        /** @deprecated */
        Object.defineProperty(this, "stopSequences", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "modelKwargs", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "codec", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: new EventStreamCodec(toUtf8, fromUtf8)
        });
        Object.defineProperty(this, "streaming", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: false
        });
        Object.defineProperty(this, "lc_serializable", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: true
        });
        this.model = fields?.model ?? this.model;
        this.modelProvider = getModelProvider(this.model);
        if (!ALLOWED_MODEL_PROVIDERS.includes(this.modelProvider)) {
            throw new Error(`Unknown model provider: '${this.modelProvider}', only these are supported: ${ALLOWED_MODEL_PROVIDERS}`);
        }
        const region = fields?.region ?? getEnvironmentVariable("AWS_DEFAULT_REGION");
        if (!region) {
            throw new Error("Please set the AWS_DEFAULT_REGION environment variable or pass it to the constructor as the region field.");
        }
        this.region = region;
        const credentials = fields?.credentials;
        if (!credentials) {
            throw new Error("Please set the AWS credentials in the 'credentials' field.");
        }
        this.credentials = credentials;
        this.temperature = fields?.temperature ?? this.temperature;
        this.maxTokens = fields?.maxTokens ?? this.maxTokens;
        // `fetch` is only looked up when no custom implementation is supplied.
        this.fetchFn = fields?.fetchFn ?? fetch.bind(globalThis);
        this.endpointHost = fields?.endpointHost ?? fields?.endpointUrl;
        this.stopSequences = fields?.stopSequences;
        this.modelKwargs = fields?.modelKwargs;
        this.streaming = fields?.streaming ?? this.streaming;
    }
    /**
     * Call out to the Bedrock service model.
     *
     * When `streaming` is enabled the chunks produced by
     * `_streamResponseChunks` are concatenated and their combined text is
     * returned; otherwise a single signed `invoke` request is made.
     *
     * @param prompt The prompt to pass into the model.
     * @param options Call options; `stop` and `signal` are read downstream.
     * @param runManager Optional callback manager, forwarded to the streaming path.
     * @returns The string generated by the model.
     * @throws {Error} `Error <status>: <detail>` when the response is not OK.
     * @example
     * response = model.invoke("Tell me a joke.")
     */
    async _call(prompt, options, runManager) {
        const service = "bedrock-runtime";
        const endpointHost = this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
        const provider = this.modelProvider;
        if (this.streaming) {
            const stream = this._streamResponseChunks(prompt, options, runManager);
            let finalResult;
            for await (const chunk of stream) {
                if (finalResult === undefined) {
                    finalResult = chunk;
                }
                else {
                    finalResult = finalResult.concat(chunk);
                }
            }
            return finalResult?.text ?? "";
        }
        const response = await this._signedFetch(prompt, options, {
            bedrockMethod: "invoke",
            endpointHost,
            provider,
        });
        if (!response.ok) {
            // Read the error body as text first: it is not guaranteed to be JSON,
            // and a JSON parse failure here would hide the HTTP status.
            const rawBody = await response.text();
            let detail = rawBody;
            try {
                const parsed = JSON.parse(rawBody);
                detail = parsed?.message ?? JSON.stringify(parsed);
            }
            catch {
                // Not JSON: report the raw body text as the detail.
            }
            throw new Error(`Error ${response.status}: ${detail}`);
        }
        const json = await response.json();
        const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
        return text;
    }
    /**
     * Builds the provider-specific request body, signs the request with
     * Signature V4 and sends it as a POST through `fetchFn`.
     *
     * @param prompt The prompt to send.
     * @param options Call options; `stop` overrides `stopSequences`, and
     *   `signal` is handed to `this.caller.callWithOptions`.
     * @param fields `{ bedrockMethod, endpointHost, provider }` describing
     *   the target endpoint and the model provider.
     * @returns The response object returned by `fetchFn`.
     */
    async _signedFetch(prompt, options, fields) {
        const { bedrockMethod, endpointHost, provider } = fields;
        const inputBody = BedrockLLMInputOutputAdapter.prepareInput(provider, prompt, this.maxTokens, this.temperature, options.stop ?? this.stopSequences, this.modelKwargs, bedrockMethod);
        const url = new URL(`https://${endpointHost}/model/${this.model}/${bedrockMethod}`);
        const request = new HttpRequest({
            hostname: url.hostname,
            path: url.pathname,
            protocol: url.protocol,
            method: "POST",
            body: JSON.stringify(inputBody),
            query: Object.fromEntries(url.searchParams.entries()),
            headers: {
                // host is required by AWS Signature V4: https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
                host: url.host,
                accept: "application/json",
                "content-type": "application/json",
            },
        });
        const signer = new SignatureV4({
            credentials: this.credentials,
            service: "bedrock",
            region: this.region,
            sha256: Sha256,
        });
        const signedRequest = await signer.sign(request);
        // Send request to AWS using the low-level fetch API
        const response = await this.caller.callWithOptions({ signal: options.signal }, async () => this.fetchFn(url, {
            headers: signedRequest.headers,
            body: signedRequest.body,
            method: signedRequest.method,
        }));
        return response;
    }
    /**
     * @param options Optional call options; `stop` overrides `stopSequences`.
     * @returns The model, region, sampling and stop parameters in effect for a call.
     */
    invocationParams(options) {
        return {
            model: this.model,
            region: this.region,
            temperature: this.temperature,
            maxTokens: this.maxTokens,
            stop: options?.stop ?? this.stopSequences,
            modelKwargs: this.modelKwargs,
        };
    }
    /**
     * Streams generation chunks for a prompt.
     *
     * For the anthropic, cohere, meta and mistral providers the
     * `invoke-with-response-stream` method is used and each decoded event
     * yields one `GenerationChunk`. For all other providers a plain
     * `invoke` is made and the full response is yielded as a single chunk.
     *
     * @param prompt The prompt to send.
     * @param options Call options forwarded to `_signedFetch`.
     * @param runManager Optional callback manager; `handleLLMNewToken` is
     *   invoked (without awaiting) for each piece of text.
     * @throws {Error} On a non-2xx status, a missing response body, an event
     *   that is not a JSON `chunk` event, or an event body carrying `message`.
     */
    async *_streamResponseChunks(prompt, options, runManager) {
        const provider = this.modelProvider;
        // Computed once so the method choice and the decoding branch cannot diverge.
        const usesEventStream = provider === "anthropic" ||
            provider === "cohere" ||
            provider === "meta" ||
            provider === "mistral";
        const bedrockMethod = usesEventStream
            ? "invoke-with-response-stream"
            : "invoke";
        const service = "bedrock-runtime";
        const endpointHost = this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
        // Send request to AWS using the low-level fetch API
        const response = await this._signedFetch(prompt, options, {
            bedrockMethod,
            endpointHost,
            provider,
        });
        if (response.status < 200 || response.status >= 300) {
            throw new Error(`Failed to access underlying url '${endpointHost}': got ${response.status} ${response.statusText}: ${await response.text()}`);
        }
        if (usesEventStream) {
            const reader = response.body?.getReader();
            if (!reader) {
                // Fail with a clear message instead of a TypeError inside _readChunks.
                throw new Error("Bedrock streaming response did not include a readable body.");
            }
            const decoder = new TextDecoder();
            for await (const chunk of this._readChunks(reader)) {
                const event = this.codec.decode(chunk);
                // Optional chaining on ":content-type" routes a missing header to
                // the descriptive error below rather than a TypeError.
                if ((event.headers[":event-type"] !== undefined &&
                    event.headers[":event-type"].value !== "chunk") ||
                    event.headers[":content-type"]?.value !== "application/json") {
                    throw new Error(`Failed to get event chunk: got ${chunk}`);
                }
                const body = JSON.parse(decoder.decode(event.body));
                if (body.message) {
                    throw new Error(body.message);
                }
                if (body.bytes !== undefined) {
                    // `bytes` is base64 text; turn it back into bytes, then decode as UTF-8 JSON.
                    const chunkResult = JSON.parse(decoder.decode(Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 0)));
                    const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, chunkResult);
                    yield new GenerationChunk({
                        text,
                        generationInfo: {},
                    });
                    // eslint-disable-next-line no-void
                    void runManager?.handleLLMNewToken(text);
                }
            }
        }
        else {
            const json = await response.json();
            const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
            yield new GenerationChunk({
                text,
                generationInfo: {},
            });
            // eslint-disable-next-line no-void
            void runManager?.handleLLMNewToken(text);
        }
    }
    /**
     * Wraps a stream reader in an async iterable that yields one complete
     * length-prefixed message at a time, buffering partial reads.
     *
     * The message length is the unsigned 32-bit big-endian integer at the
     * start of the buffered data. If iteration stops before the stream is
     * exhausted (the consumer throws or breaks), the reader is cancelled so
     * the underlying body is not left open.
     *
     * @param reader A reader exposing `read()` and `cancel()`.
     * @returns An async iterable of `Uint8Array` messages.
     */
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    _readChunks(reader) {
        function _concatChunks(a, b) {
            const newBuffer = new Uint8Array(a.length + b.length);
            newBuffer.set(a);
            newBuffer.set(b, a.length);
            return newBuffer;
        }
        function getMessageLength(buffer) {
            if (buffer.byteLength < PRELUDE_TOTAL_LENGTH_BYTES)
                return 0;
            const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
            return view.getUint32(0, false);
        }
        return {
            async *[Symbol.asyncIterator]() {
                let exhausted = false;
                try {
                    let readResult = await reader.read();
                    let buffer = new Uint8Array(0);
                    while (!readResult.done) {
                        const chunk = readResult.value;
                        buffer = _concatChunks(buffer, chunk);
                        let messageLength = getMessageLength(buffer);
                        // One read may contain several messages, or only part of one.
                        while (buffer.byteLength >= PRELUDE_TOTAL_LENGTH_BYTES &&
                            buffer.byteLength >= messageLength) {
                            yield buffer.slice(0, messageLength);
                            buffer = buffer.slice(messageLength);
                            messageLength = getMessageLength(buffer);
                        }
                        readResult = await reader.read();
                    }
                    exhausted = true;
                }
                finally {
                    if (!exhausted) {
                        try {
                            await reader.cancel();
                        }
                        catch {
                            // Best-effort cleanup: the error that ended iteration
                            // (if any) must be the one that propagates.
                        }
                    }
                }
            },
        };
    }
}
/**
 * Reports whether a model id begins with one of the prefixes listed in
 * `AWS_REGIONS`, judged by its first dot-separated segment.
 *
 * @param {string} modelId Model identifier, e.g. `us.anthropic.…`.
 * @returns {boolean} `true` when the leading segment is a known prefix.
 */
function isInferenceModel(modelId) {
    const [prefix] = modelId.split(".");
    return AWS_REGIONS.includes(prefix);
}
/**
 * Extracts the provider segment from a dot-separated model id: the second
 * segment when `isInferenceModel` reports a leading prefix, otherwise the
 * first.
 *
 * @param {string} modelId Model identifier.
 * @returns {string | undefined} The provider segment (`undefined` if the
 *   selected segment does not exist).
 */
function getModelProvider(modelId) {
    const segments = modelId.split(".");
    const providerIndex = isInferenceModel(modelId) ? 1 : 0;
    return segments[providerIndex];
}