agsamantha/node_modules/@langchain/community/dist/llms/ai21.js
2024-10-02 15:15:21 -05:00

204 lines
7.1 KiB
JavaScript

import { LLM } from "@langchain/core/language_models/llms";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
/**
 * Class representing the AI21 language model. It extends the LLM (Large
 * Language Model) class, providing a standard interface for interacting
 * with the AI21 completion endpoint.
 */
export class AI21 extends LLM {
  /**
   * @param {object} [fields] - Optional configuration. Recognised keys:
   *   `model`, `temperature`, `maxTokens`, `minTokens`, `topP`,
   *   `presencePenalty`, `countPenalty`, `frequencyPenalty`, `numResults`,
   *   `logitBias`, `ai21ApiKey`, `stop` and `baseUrl`. The whole object is
   *   also forwarded to the `LLM` base constructor.
   */
  constructor(fields) {
    super(fields ?? {});
    // Every field is created as an enumerable, configurable, writable own
    // data property, exactly as the individual defineProperty calls did.
    const defineField = (name, value) => {
      Object.defineProperty(this, name, {
        enumerable: true,
        configurable: true,
        writable: true,
        value,
      });
    };
    defineField("lc_serializable", true);
    // `??` (not `||`) so that legitimate falsy overrides such as
    // `temperature: 0` or `minTokens: 0` are respected.
    defineField("model", fields?.model ?? "j2-jumbo-instruct");
    defineField("temperature", fields?.temperature ?? 0.7);
    defineField("maxTokens", fields?.maxTokens ?? 1024);
    defineField("minTokens", fields?.minTokens ?? 0);
    defineField("topP", fields?.topP ?? 1);
    defineField(
      "presencePenalty",
      fields?.presencePenalty ?? AI21.getDefaultAI21PenaltyData()
    );
    defineField(
      "countPenalty",
      fields?.countPenalty ?? AI21.getDefaultAI21PenaltyData()
    );
    defineField(
      "frequencyPenalty",
      fields?.frequencyPenalty ?? AI21.getDefaultAI21PenaltyData()
    );
    defineField("numResults", fields?.numResults ?? 1);
    defineField("logitBias", fields?.logitBias);
    // An explicitly passed key wins over the AI21_API_KEY environment variable.
    defineField(
      "ai21ApiKey",
      fields?.ai21ApiKey ?? getEnvironmentVariable("AI21_API_KEY")
    );
    defineField("stop", fields?.stop);
    defineField("baseUrl", fields?.baseUrl);
  }

  /**
   * Validates the environment by checking that an AI21 API key is set.
   * @throws {Error} If no API key was passed in or found in the environment.
   */
  validateEnvironment() {
    if (!this.ai21ApiKey) {
      throw new Error(`No AI21 API key found. Please set it as "AI21_API_KEY" in your environment variables.`);
    }
  }

  /**
   * Static method to get the default penalty data for AI21.
   * A fresh object is returned on every call, so instances never share
   * (and cannot accidentally mutate) each other's penalty settings.
   * @returns AI21PenaltyData
   */
  static getDefaultAI21PenaltyData() {
    return {
      scale: 0,
      applyToWhitespaces: true,
      applyToPunctuations: true,
      applyToNumbers: true,
      applyToStopwords: true,
      applyToEmojis: true,
    };
  }

  /** Get the type of LLM. */
  _llmType() {
    return "ai21";
  }

  /** Get the default parameters for calling AI21 API. */
  get defaultParams() {
    return {
      temperature: this.temperature,
      maxTokens: this.maxTokens,
      minTokens: this.minTokens,
      topP: this.topP,
      presencePenalty: this.presencePenalty,
      countPenalty: this.countPenalty,
      frequencyPenalty: this.frequencyPenalty,
      numResults: this.numResults,
      logitBias: this.logitBias,
    };
  }

  /** Get the identifying parameters for this LLM. */
  get identifyingParams() {
    return { ...this.defaultParams, model: this.model };
  }

  /**
   * Call out to AI21's complete endpoint.
   *
   * @param {string} prompt - The prompt to pass into the model.
   * @param {object} [options] - Optional call options; `stop` (list of stop
   *   sequences) and `signal` (abort signal handed to `fetch`) are read.
   * @returns {Promise<string>} The text of the first completion, or `""`
   *   when the completion carries no text.
   * @throws {Error} If no API key is configured, if non-empty stop lists are
   *   given both on the instance and in `options`, if the HTTP response is
   *   not OK (the response is attached as `error.response`), or if the
   *   response contains no completions.
   *
   * @example
   * const response = await ai21._call("Tell me a joke.");
   */
  async _call(prompt, options) {
    this.validateEnvironment();
    const callStop = options?.stop;
    if (this.stop && callStop && this.stop.length > 0 && callStop.length > 0) {
      throw new Error("`stop` found in both the input and default params.");
    }
    // Use whichever list is non-empty. A plain `this.stop ?? callStop` would
    // let an empty instance-level list shadow the call-time stop sequences.
    const stop = this.stop?.length ? this.stop : (callStop ?? []);
    // Pick the default first, then apply `??`. Written inline, the ternary
    // binds looser than `??`, which caused a configured baseUrl to be ignored.
    const defaultBaseUrl =
      this.model === "j1-grande-instruct"
        ? "https://api.ai21.com/studio/v1/experimental"
        : "https://api.ai21.com/studio/v1";
    const baseUrl = this.baseUrl ?? defaultBaseUrl;
    const url = `${baseUrl}/${this.model}/complete`;
    const headers = {
      Authorization: `Bearer ${this.ai21ApiKey}`,
      "Content-Type": "application/json",
    };
    const data = { prompt, stopSequences: stop, ...this.defaultParams };
    const responseData = await this.caller.callWithOptions({}, async () => {
      const response = await fetch(url, {
        method: "POST",
        headers,
        body: JSON.stringify(data),
        // `options` is optional, so guard the access like `options?.stop` above.
        signal: options?.signal,
      });
      if (!response.ok) {
        const error = new Error(`AI21 call failed with status code ${response.status}`);
        // Expose the raw response so callers can inspect status and body.
        error.response = response;
        throw error;
      }
      return response.json();
    });
    if (!responseData.completions ||
      responseData.completions.length === 0 ||
      !responseData.completions[0].data) {
      throw new Error("No completions found in response");
    }
    return responseData.completions[0].data.text ?? "";
  }
}