205 lines
7.1 KiB
JavaScript
205 lines
7.1 KiB
JavaScript
|
import { LLM } from "@langchain/core/language_models/llms";
|
||
|
import { getEnvironmentVariable } from "@langchain/core/utils/env";
|
||
|
/**
 * Class representing the AI21 language model. It extends the LLM (Large
 * Language Model) class, providing a standard interface for interacting
 * with the AI21 language model through its REST `complete` endpoint.
 */
export class AI21 extends LLM {
  /**
   * @param {object} [fields] - Optional configuration. Recognised keys are
   *   `model`, `temperature`, `maxTokens`, `minTokens`, `topP`,
   *   `presencePenalty`, `countPenalty`, `frequencyPenalty`, `numResults`,
   *   `logitBias`, `ai21ApiKey`, `stop` and `baseUrl`. Keys that are
   *   `null`/`undefined` keep the defaults declared below. The object is
   *   also forwarded to the `LLM` base constructor.
   */
  constructor(fields) {
    super(fields ?? {});

    // Declare every instance field as an own, enumerable, writable,
    // configurable property (the descriptor a TypeScript class-field
    // declaration compiles to) before user-supplied values are applied.
    const defineField = (key, value) =>
      Object.defineProperty(this, key, {
        enumerable: true,
        configurable: true,
        writable: true,
        value,
      });
    defineField("lc_serializable", true);
    defineField("model", "j2-jumbo-instruct");
    defineField("temperature", 0.7);
    defineField("maxTokens", 1024);
    defineField("minTokens", 0);
    defineField("topP", 1);
    // Each penalty default is a fresh object, so instances never share one.
    defineField("presencePenalty", AI21.getDefaultAI21PenaltyData());
    defineField("countPenalty", AI21.getDefaultAI21PenaltyData());
    defineField("frequencyPenalty", AI21.getDefaultAI21PenaltyData());
    defineField("numResults", 1);
    defineField("logitBias", undefined);
    defineField("ai21ApiKey", undefined);
    defineField("stop", undefined);
    defineField("baseUrl", undefined);

    this.model = fields?.model ?? this.model;
    this.temperature = fields?.temperature ?? this.temperature;
    this.maxTokens = fields?.maxTokens ?? this.maxTokens;
    this.minTokens = fields?.minTokens ?? this.minTokens;
    this.topP = fields?.topP ?? this.topP;
    this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty;
    this.countPenalty = fields?.countPenalty ?? this.countPenalty;
    this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty;
    this.numResults = fields?.numResults ?? this.numResults;
    this.logitBias = fields?.logitBias;
    // An explicit key wins; otherwise fall back to the environment.
    this.ai21ApiKey =
      fields?.ai21ApiKey ?? getEnvironmentVariable("AI21_API_KEY");
    this.stop = fields?.stop;
    this.baseUrl = fields?.baseUrl;
  }

  /**
   * Validate the environment: ensure an AI21 API key is available.
   *
   * @throws {Error} If `ai21ApiKey` is empty or unset.
   */
  validateEnvironment() {
    if (!this.ai21ApiKey) {
      throw new Error(`No AI21 API key found. Please set it as "AI21_API_KEY" in your environment variables.`);
    }
  }

  /**
   * Get the default penalty data for AI21: zero scale, applied to every
   * token category. A new object is returned on each call.
   *
   * @returns {object} AI21PenaltyData
   */
  static getDefaultAI21PenaltyData() {
    return {
      scale: 0,
      applyToWhitespaces: true,
      applyToPunctuations: true,
      applyToNumbers: true,
      applyToStopwords: true,
      applyToEmojis: true,
    };
  }

  /** Get the type of LLM. */
  _llmType() {
    return "ai21";
  }

  /** Get the default parameters sent in every AI21 `complete` request body. */
  get defaultParams() {
    return {
      temperature: this.temperature,
      maxTokens: this.maxTokens,
      minTokens: this.minTokens,
      topP: this.topP,
      presencePenalty: this.presencePenalty,
      countPenalty: this.countPenalty,
      frequencyPenalty: this.frequencyPenalty,
      numResults: this.numResults,
      logitBias: this.logitBias,
    };
  }

  /** Get the identifying parameters for this LLM (default params + model). */
  get identifyingParams() {
    return { ...this.defaultParams, model: this.model };
  }

  /**
   * Call out to AI21's `complete` endpoint.
   *
   * @param {string} prompt - The prompt to pass into the model.
   * @param {object} [options] - Call options. `options.stop` is an optional
   *   list of stop sequences; `options.signal` is an optional AbortSignal
   *   forwarded both to the async caller and to `fetch`.
   * @returns {Promise<string>} The text of the first completion, or `""` if
   *   that completion carries no `text`.
   * @throws {Error} If no API key is configured, if non-empty `stop` lists
   *   are given both here and on the instance, if the HTTP status is not OK,
   *   or if the response contains no completion data.
   *
   * @example
   * const response = await ai21._call("Tell me a joke.");
   */
  async _call(prompt, options) {
    this.validateEnvironment();

    const callStop = options?.stop;
    const hasDefaultStop = Boolean(this.stop && this.stop.length > 0);
    const hasCallStop = Boolean(callStop && callStop.length > 0);
    if (hasDefaultStop && hasCallStop) {
      throw new Error("`stop` found in both the input and default params.");
    }
    // Use the instance-level list only when it is non-empty, so an empty
    // default never masks a per-call `stop`.
    const stop = hasDefaultStop ? this.stop : callStop ?? [];

    // Parenthesised on purpose: `??` binds tighter than `?:`, so without the
    // parentheses a custom `baseUrl` would be ignored.
    const baseUrl =
      this.baseUrl ??
      (this.model === "j1-grande-instruct"
        ? "https://api.ai21.com/studio/v1/experimental"
        : "https://api.ai21.com/studio/v1");
    const url = `${baseUrl}/${this.model}/complete`;
    const headers = {
      Authorization: `Bearer ${this.ai21ApiKey}`,
      "Content-Type": "application/json",
    };
    const data = { prompt, stopSequences: stop, ...this.defaultParams };
    const signal = options?.signal;

    const responseData = await this.caller.callWithOptions(
      { signal },
      async () => {
        const response = await fetch(url, {
          method: "POST",
          headers,
          body: JSON.stringify(data),
          signal,
        });
        if (!response.ok) {
          const error = new Error(`AI21 call failed with status code ${response.status}`);
          // Expose the raw response so callers/retry logic can inspect it.
          error.response = response;
          throw error;
        }
        return response.json();
      }
    );

    const completion = responseData.completions?.[0];
    if (!completion?.data) {
      throw new Error("No completions found in response");
    }
    return completion.data.text ?? "";
  }
}
|