353 lines
13 KiB
JavaScript
353 lines
13 KiB
JavaScript
import { SignatureV4 } from "@smithy/signature-v4";
|
|
import { HttpRequest } from "@smithy/protocol-http";
|
|
import { EventStreamCodec } from "@smithy/eventstream-codec";
|
|
import { fromUtf8, toUtf8 } from "@smithy/util-utf8";
|
|
import { Sha256 } from "@aws-crypto/sha256-js";
|
|
import { getEnvironmentVariable } from "@langchain/core/utils/env";
|
|
import { GenerationChunk } from "@langchain/core/outputs";
|
|
import { LLM } from "@langchain/core/language_models/llms";
|
|
import { BedrockLLMInputOutputAdapter, } from "../../utils/bedrock/index.js";
|
|
const AWS_REGIONS = [
|
|
"us",
|
|
"sa",
|
|
"me",
|
|
"il",
|
|
"eu",
|
|
"cn",
|
|
"ca",
|
|
"ap",
|
|
"af",
|
|
"us-gov",
|
|
];
|
|
const ALLOWED_MODEL_PROVIDERS = [
|
|
"ai21",
|
|
"anthropic",
|
|
"amazon",
|
|
"cohere",
|
|
"meta",
|
|
"mistral",
|
|
];
|
|
const PRELUDE_TOTAL_LENGTH_BYTES = 4;
|
|
/**
|
|
* A type of Large Language Model (LLM) that interacts with the Bedrock
|
|
* service. It extends the base `LLM` class and implements the
|
|
* `BaseBedrockInput` interface. The class is designed to authenticate and
|
|
* interact with the Bedrock service, which is a part of Amazon Web
|
|
* Services (AWS). It uses AWS credentials for authentication and can be
|
|
* configured with various parameters such as the model to use, the AWS
|
|
* region, and the maximum number of tokens to generate.
|
|
*/
|
|
export class Bedrock extends LLM {
|
|
get lc_aliases() {
|
|
return {
|
|
model: "model_id",
|
|
region: "region_name",
|
|
};
|
|
}
|
|
get lc_secrets() {
|
|
return {
|
|
"credentials.accessKeyId": "BEDROCK_AWS_ACCESS_KEY_ID",
|
|
"credentials.secretAccessKey": "BEDROCK_AWS_SECRET_ACCESS_KEY",
|
|
};
|
|
}
|
|
get lc_attributes() {
|
|
return { region: this.region };
|
|
}
|
|
_llmType() {
|
|
return "bedrock";
|
|
}
|
|
static lc_name() {
|
|
return "Bedrock";
|
|
}
|
|
constructor(fields) {
|
|
super(fields ?? {});
|
|
Object.defineProperty(this, "model", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: "amazon.titan-tg1-large"
|
|
});
|
|
Object.defineProperty(this, "modelProvider", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: void 0
|
|
});
|
|
Object.defineProperty(this, "region", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: void 0
|
|
});
|
|
Object.defineProperty(this, "credentials", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: void 0
|
|
});
|
|
Object.defineProperty(this, "temperature", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: undefined
|
|
});
|
|
Object.defineProperty(this, "maxTokens", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: undefined
|
|
});
|
|
Object.defineProperty(this, "fetchFn", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: void 0
|
|
});
|
|
Object.defineProperty(this, "endpointHost", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: void 0
|
|
});
|
|
/** @deprecated */
|
|
Object.defineProperty(this, "stopSequences", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: void 0
|
|
});
|
|
Object.defineProperty(this, "modelKwargs", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: void 0
|
|
});
|
|
Object.defineProperty(this, "codec", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: new EventStreamCodec(toUtf8, fromUtf8)
|
|
});
|
|
Object.defineProperty(this, "streaming", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: false
|
|
});
|
|
Object.defineProperty(this, "lc_serializable", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: true
|
|
});
|
|
this.model = fields?.model ?? this.model;
|
|
this.modelProvider = getModelProvider(this.model);
|
|
if (!ALLOWED_MODEL_PROVIDERS.includes(this.modelProvider)) {
|
|
throw new Error(`Unknown model provider: '${this.modelProvider}', only these are supported: ${ALLOWED_MODEL_PROVIDERS}`);
|
|
}
|
|
const region = fields?.region ?? getEnvironmentVariable("AWS_DEFAULT_REGION");
|
|
if (!region) {
|
|
throw new Error("Please set the AWS_DEFAULT_REGION environment variable or pass it to the constructor as the region field.");
|
|
}
|
|
this.region = region;
|
|
const credentials = fields?.credentials;
|
|
if (!credentials) {
|
|
throw new Error("Please set the AWS credentials in the 'credentials' field.");
|
|
}
|
|
this.credentials = credentials;
|
|
this.temperature = fields?.temperature ?? this.temperature;
|
|
this.maxTokens = fields?.maxTokens ?? this.maxTokens;
|
|
this.fetchFn = fields?.fetchFn ?? fetch.bind(globalThis);
|
|
this.endpointHost = fields?.endpointHost ?? fields?.endpointUrl;
|
|
this.stopSequences = fields?.stopSequences;
|
|
this.modelKwargs = fields?.modelKwargs;
|
|
this.streaming = fields?.streaming ?? this.streaming;
|
|
}
|
|
/** Call out to Bedrock service model.
|
|
Arguments:
|
|
prompt: The prompt to pass into the model.
|
|
|
|
Returns:
|
|
The string generated by the model.
|
|
|
|
Example:
|
|
response = model.invoke("Tell me a joke.")
|
|
*/
|
|
async _call(prompt, options, runManager) {
|
|
const service = "bedrock-runtime";
|
|
const endpointHost = this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
|
|
const provider = this.modelProvider;
|
|
if (this.streaming) {
|
|
const stream = this._streamResponseChunks(prompt, options, runManager);
|
|
let finalResult;
|
|
for await (const chunk of stream) {
|
|
if (finalResult === undefined) {
|
|
finalResult = chunk;
|
|
}
|
|
else {
|
|
finalResult = finalResult.concat(chunk);
|
|
}
|
|
}
|
|
return finalResult?.text ?? "";
|
|
}
|
|
const response = await this._signedFetch(prompt, options, {
|
|
bedrockMethod: "invoke",
|
|
endpointHost,
|
|
provider,
|
|
});
|
|
const json = await response.json();
|
|
if (!response.ok) {
|
|
throw new Error(`Error ${response.status}: ${json.message ?? JSON.stringify(json)}`);
|
|
}
|
|
const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
|
|
return text;
|
|
}
|
|
async _signedFetch(prompt, options, fields) {
|
|
const { bedrockMethod, endpointHost, provider } = fields;
|
|
const inputBody = BedrockLLMInputOutputAdapter.prepareInput(provider, prompt, this.maxTokens, this.temperature, options.stop ?? this.stopSequences, this.modelKwargs, fields.bedrockMethod);
|
|
const url = new URL(`https://${endpointHost}/model/${this.model}/${bedrockMethod}`);
|
|
const request = new HttpRequest({
|
|
hostname: url.hostname,
|
|
path: url.pathname,
|
|
protocol: url.protocol,
|
|
method: "POST",
|
|
body: JSON.stringify(inputBody),
|
|
query: Object.fromEntries(url.searchParams.entries()),
|
|
headers: {
|
|
// host is required by AWS Signature V4: https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
|
|
host: url.host,
|
|
accept: "application/json",
|
|
"content-type": "application/json",
|
|
},
|
|
});
|
|
const signer = new SignatureV4({
|
|
credentials: this.credentials,
|
|
service: "bedrock",
|
|
region: this.region,
|
|
sha256: Sha256,
|
|
});
|
|
const signedRequest = await signer.sign(request);
|
|
// Send request to AWS using the low-level fetch API
|
|
const response = await this.caller.callWithOptions({ signal: options.signal }, async () => this.fetchFn(url, {
|
|
headers: signedRequest.headers,
|
|
body: signedRequest.body,
|
|
method: signedRequest.method,
|
|
}));
|
|
return response;
|
|
}
|
|
invocationParams(options) {
|
|
return {
|
|
model: this.model,
|
|
region: this.region,
|
|
temperature: this.temperature,
|
|
maxTokens: this.maxTokens,
|
|
stop: options?.stop ?? this.stopSequences,
|
|
modelKwargs: this.modelKwargs,
|
|
};
|
|
}
|
|
async *_streamResponseChunks(prompt, options, runManager) {
|
|
const provider = this.modelProvider;
|
|
const bedrockMethod = provider === "anthropic" ||
|
|
provider === "cohere" ||
|
|
provider === "meta" ||
|
|
provider === "mistral"
|
|
? "invoke-with-response-stream"
|
|
: "invoke";
|
|
const service = "bedrock-runtime";
|
|
const endpointHost = this.endpointHost ?? `${service}.${this.region}.amazonaws.com`;
|
|
// Send request to AWS using the low-level fetch API
|
|
const response = await this._signedFetch(prompt, options, {
|
|
bedrockMethod,
|
|
endpointHost,
|
|
provider,
|
|
});
|
|
if (response.status < 200 || response.status >= 300) {
|
|
throw Error(`Failed to access underlying url '${endpointHost}': got ${response.status} ${response.statusText}: ${await response.text()}`);
|
|
}
|
|
if (provider === "anthropic" ||
|
|
provider === "cohere" ||
|
|
provider === "meta" ||
|
|
provider === "mistral") {
|
|
const reader = response.body?.getReader();
|
|
const decoder = new TextDecoder();
|
|
for await (const chunk of this._readChunks(reader)) {
|
|
const event = this.codec.decode(chunk);
|
|
if ((event.headers[":event-type"] !== undefined &&
|
|
event.headers[":event-type"].value !== "chunk") ||
|
|
event.headers[":content-type"].value !== "application/json") {
|
|
throw Error(`Failed to get event chunk: got ${chunk}`);
|
|
}
|
|
const body = JSON.parse(decoder.decode(event.body));
|
|
if (body.message) {
|
|
throw new Error(body.message);
|
|
}
|
|
if (body.bytes !== undefined) {
|
|
const chunkResult = JSON.parse(decoder.decode(Uint8Array.from(atob(body.bytes), (m) => m.codePointAt(0) ?? 0)));
|
|
const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, chunkResult);
|
|
yield new GenerationChunk({
|
|
text,
|
|
generationInfo: {},
|
|
});
|
|
// eslint-disable-next-line no-void
|
|
void runManager?.handleLLMNewToken(text);
|
|
}
|
|
}
|
|
}
|
|
else {
|
|
const json = await response.json();
|
|
const text = BedrockLLMInputOutputAdapter.prepareOutput(provider, json);
|
|
yield new GenerationChunk({
|
|
text,
|
|
generationInfo: {},
|
|
});
|
|
// eslint-disable-next-line no-void
|
|
void runManager?.handleLLMNewToken(text);
|
|
}
|
|
}
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
_readChunks(reader) {
|
|
function _concatChunks(a, b) {
|
|
const newBuffer = new Uint8Array(a.length + b.length);
|
|
newBuffer.set(a);
|
|
newBuffer.set(b, a.length);
|
|
return newBuffer;
|
|
}
|
|
function getMessageLength(buffer) {
|
|
if (buffer.byteLength < PRELUDE_TOTAL_LENGTH_BYTES)
|
|
return 0;
|
|
const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
|
|
return view.getUint32(0, false);
|
|
}
|
|
return {
|
|
async *[Symbol.asyncIterator]() {
|
|
let readResult = await reader.read();
|
|
let buffer = new Uint8Array(0);
|
|
while (!readResult.done) {
|
|
const chunk = readResult.value;
|
|
buffer = _concatChunks(buffer, chunk);
|
|
let messageLength = getMessageLength(buffer);
|
|
while (buffer.byteLength >= PRELUDE_TOTAL_LENGTH_BYTES &&
|
|
buffer.byteLength >= messageLength) {
|
|
yield buffer.slice(0, messageLength);
|
|
buffer = buffer.slice(messageLength);
|
|
messageLength = getMessageLength(buffer);
|
|
}
|
|
readResult = await reader.read();
|
|
}
|
|
},
|
|
};
|
|
}
|
|
}
|
|
function isInferenceModel(modelId) {
|
|
const parts = modelId.split(".");
|
|
return AWS_REGIONS.some((region) => parts[0] === region);
|
|
}
|
|
function getModelProvider(modelId) {
|
|
const parts = modelId.split(".");
|
|
if (isInferenceModel(modelId)) {
|
|
return parts[1];
|
|
}
|
|
else {
|
|
return parts[0];
|
|
}
|
|
}
|