// agsamantha/node_modules/@langchain/openai/dist/llms.js
// (570 lines, 21 KiB, JavaScript; snapshot dated 2024-10-02 15:15:21 -05:00)
import { OpenAI as OpenAIClient } from "openai";
import { calculateMaxTokens } from "@langchain/core/language_models/base";
import { GenerationChunk } from "@langchain/core/outputs";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { BaseLLM, } from "@langchain/core/language_models/llms";
import { chunkArray } from "@langchain/core/utils/chunk_array";
import { getEndpoint } from "./utils/azure.js";
import { OpenAIChat } from "./legacy.js";
import { wrapOpenAIClientError } from "./utils/openai.js";
export { OpenAIChat };
/**
* Wrapper around OpenAI large language models.
*
* To use you should have the `openai` package installed, with the
* `OPENAI_API_KEY` environment variable set.
*
* To use with Azure you should have the `openai` package installed, with the
* `AZURE_OPENAI_API_KEY`,
* `AZURE_OPENAI_API_INSTANCE_NAME`,
* `AZURE_OPENAI_API_DEPLOYMENT_NAME`
* and `AZURE_OPENAI_API_VERSION` environment variables set.
*
* @remarks
* Any parameters that are valid to be passed to {@link
* https://platform.openai.com/docs/api-reference/completions/create |
* `openai.completions.create`} can be passed through {@link modelKwargs}, even
* if not explicitly available on this class.
* @example
* ```typescript
* const model = new OpenAI({
* model: "gpt-3.5-turbo-instruct",
* temperature: 0.7,
* maxTokens: 1000,
* maxRetries: 5,
* });
*
* const res = await model.invoke(
* "Question: What would be a good company name for a company that makes colorful socks?\nAnswer:"
* );
* console.log({ res });
* ```
*/
export class OpenAI extends BaseLLM {
    static lc_name() {
        return "OpenAI";
    }
    get callKeys() {
        return [...super.callKeys, "options"];
    }
    get lc_secrets() {
        return {
            openAIApiKey: "OPENAI_API_KEY",
            apiKey: "OPENAI_API_KEY",
            azureOpenAIApiKey: "AZURE_OPENAI_API_KEY",
            organization: "OPENAI_ORGANIZATION",
        };
    }
    get lc_aliases() {
        return {
            modelName: "model",
            openAIApiKey: "openai_api_key",
            apiKey: "openai_api_key",
            azureOpenAIApiVersion: "azure_openai_api_version",
            azureOpenAIApiKey: "azure_openai_api_key",
            azureOpenAIApiInstanceName: "azure_openai_api_instance_name",
            azureOpenAIApiDeploymentName: "azure_openai_api_deployment_name",
        };
    }
    /**
     * @param fields - Model, sampling, OpenAI and Azure OpenAI settings.
     * @param configuration - Deprecated; extra client configuration. It is merged
     *   into the client config before `fields.configuration`.
     * @throws {Error} If no OpenAI/Azure API key or Azure token provider is found,
     *   if streaming is combined with `bestOf > 1`, or if Azure is selected but
     *   its instance name/base path, deployment name or API version is missing.
     */
    constructor(fields,
    /** @deprecated */
    configuration) {
        let model = fields?.model ?? fields?.modelName;
        // Chat models are not completions models: hand back an OpenAIChat
        // instance instead (kept only for backwards-compatibility).
        if ((model?.startsWith("gpt-3.5-turbo") || model?.startsWith("gpt-4")) &&
            !model?.includes("-instruct")) {
            console.warn([
                `Your chosen OpenAI model, "${model}", is a chat model and not a text-in/text-out LLM.`,
                `Passing it into the "OpenAI" class is deprecated and only permitted for backwards-compatibility. You may experience odd behavior.`,
                `Please use the "ChatOpenAI" class instead.`,
                "",
                `See this page for more information:`,
                "|",
                `└> https://js.langchain.com/docs/integrations/chat/openai`,
            ].join("\n"));
            // eslint-disable-next-line no-constructor-return
            return new OpenAIChat(fields, configuration);
        }
        super(fields ?? {});
        Object.defineProperty(this, "lc_serializable", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: true
        });
        Object.defineProperty(this, "temperature", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 0.7
        });
        Object.defineProperty(this, "maxTokens", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 256
        });
        Object.defineProperty(this, "topP", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 1
        });
        Object.defineProperty(this, "frequencyPenalty", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 0
        });
        Object.defineProperty(this, "presencePenalty", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 0
        });
        Object.defineProperty(this, "n", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 1
        });
        Object.defineProperty(this, "bestOf", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "logitBias", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "modelName", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "gpt-3.5-turbo-instruct"
        });
        Object.defineProperty(this, "model", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "gpt-3.5-turbo-instruct"
        });
        Object.defineProperty(this, "modelKwargs", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "batchSize", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 20
        });
        Object.defineProperty(this, "timeout", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "stop", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "stopSequences", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "user", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "streaming", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: false
        });
        Object.defineProperty(this, "openAIApiKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "apiKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "azureOpenAIApiVersion", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "azureOpenAIApiKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "azureADTokenProvider", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "azureOpenAIApiInstanceName", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "azureOpenAIApiDeploymentName", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "azureOpenAIBasePath", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "organization", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "client", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "clientConfig", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        model = model ?? this.model;
        this.openAIApiKey =
            fields?.apiKey ??
                fields?.openAIApiKey ??
                getEnvironmentVariable("OPENAI_API_KEY");
        this.apiKey = this.openAIApiKey;
        this.azureOpenAIApiKey =
            fields?.azureOpenAIApiKey ??
                getEnvironmentVariable("AZURE_OPENAI_API_KEY");
        this.azureADTokenProvider = fields?.azureADTokenProvider ?? undefined;
        if (!this.azureOpenAIApiKey && !this.apiKey && !this.azureADTokenProvider) {
            throw new Error("OpenAI or Azure OpenAI API key or Token Provider not found");
        }
        this.azureOpenAIApiInstanceName =
            fields?.azureOpenAIApiInstanceName ??
                getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME");
        // A completions-specific deployment name takes precedence over the
        // generic one, both for explicit fields and for environment variables.
        this.azureOpenAIApiDeploymentName =
            (fields?.azureOpenAIApiCompletionsDeploymentName ||
                fields?.azureOpenAIApiDeploymentName) ??
                (getEnvironmentVariable("AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME") ||
                    getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME"));
        this.azureOpenAIApiVersion =
            fields?.azureOpenAIApiVersion ??
                getEnvironmentVariable("AZURE_OPENAI_API_VERSION");
        this.azureOpenAIBasePath =
            fields?.azureOpenAIBasePath ??
                getEnvironmentVariable("AZURE_OPENAI_BASE_PATH");
        this.organization =
            fields?.configuration?.organization ??
                getEnvironmentVariable("OPENAI_ORGANIZATION");
        this.modelName = model;
        this.model = model;
        this.modelKwargs = fields?.modelKwargs ?? {};
        this.batchSize = fields?.batchSize ?? this.batchSize;
        this.timeout = fields?.timeout;
        this.temperature = fields?.temperature ?? this.temperature;
        this.maxTokens = fields?.maxTokens ?? this.maxTokens;
        this.topP = fields?.topP ?? this.topP;
        this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty;
        this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty;
        this.n = fields?.n ?? this.n;
        this.bestOf = fields?.bestOf ?? this.bestOf;
        this.logitBias = fields?.logitBias;
        this.stop = fields?.stopSequences ?? fields?.stop;
        // Alias both spellings: invocationParams() reads `stopSequences`, so a
        // `stop` constructor field would otherwise be silently ignored.
        this.stopSequences = this.stop;
        this.user = fields?.user;
        this.streaming = fields?.streaming ?? false;
        if (this.streaming && this.bestOf && this.bestOf > 1) {
            throw new Error("Cannot stream results when bestOf > 1");
        }
        if (this.azureOpenAIApiKey || this.azureADTokenProvider) {
            if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) {
                throw new Error("Azure OpenAI API instance name not found");
            }
            if (!this.azureOpenAIApiDeploymentName) {
                throw new Error("Azure OpenAI API deployment name not found");
            }
            if (!this.azureOpenAIApiVersion) {
                throw new Error("Azure OpenAI API version not found");
            }
            // Azure auth is added per request in _getClientOptions(); the client
            // still needs a string apiKey.
            this.apiKey = this.apiKey ?? "";
        }
        this.clientConfig = {
            apiKey: this.apiKey,
            organization: this.organization,
            baseURL: configuration?.basePath ?? fields?.configuration?.basePath,
            dangerouslyAllowBrowser: true,
            defaultHeaders: configuration?.baseOptions?.headers ??
                fields?.configuration?.baseOptions?.headers,
            defaultQuery: configuration?.baseOptions?.params ??
                fields?.configuration?.baseOptions?.params,
            ...configuration,
            ...fields?.configuration,
        };
    }
    /**
     * Get the parameters used to invoke the model.
     *
     * @param options - Parsed call options; `options.stop` overrides the
     *   instance stop sequences.
     * @returns Completions request body fields; `modelKwargs` override the rest.
     */
    invocationParams(options) {
        return {
            model: this.model,
            temperature: this.temperature,
            max_tokens: this.maxTokens,
            top_p: this.topP,
            frequency_penalty: this.frequencyPenalty,
            presence_penalty: this.presencePenalty,
            n: this.n,
            best_of: this.bestOf,
            logit_bias: this.logitBias,
            stop: options?.stop ?? this.stopSequences,
            user: this.user,
            stream: this.streaming,
            ...this.modelKwargs,
        };
    }
    /** @ignore */
    _identifyingParams() {
        return {
            model_name: this.model,
            ...this.invocationParams(),
            ...this.clientConfig,
        };
    }
    /**
     * Get the identifying parameters for the model
     */
    identifyingParams() {
        return this._identifyingParams();
    }
    /**
     * Call out to OpenAI's endpoint with k unique prompts.
     *
     * Prompts are sent in batches of `batchSize`. Each prompt yields `n`
     * choices, which are regrouped per prompt in the result.
     *
     * @param [prompts] - The prompts to pass into the model.
     * @param [options] - Parsed call options (stop words, abort signal, and raw
     *   client request options under `options.options`).
     * @param [runManager] - Optional callback manager to use when generating.
     *
     * @returns The full LLM output.
     * @throws {Error} If `maxTokens` is -1 with more than one prompt, or
     *   "AbortError" if the signal is aborted while streaming.
     *
     * @example
     * ```ts
     * import { OpenAI } from "@langchain/openai";
     * const openai = new OpenAI();
     * const response = await openai.generate(["Tell me a joke."]);
     * ```
     */
    async _generate(prompts, options, runManager) {
        const subPrompts = chunkArray(prompts, this.batchSize);
        const choices = [];
        const tokenUsage = {};
        const params = this.invocationParams(options);
        if (params.max_tokens === -1) {
            if (prompts.length !== 1) {
                throw new Error("max_tokens set to -1 not supported for multiple inputs");
            }
            params.max_tokens = await calculateMaxTokens({
                prompt: prompts[0],
                modelName: this.model,
            });
        }
        const requestOptions = this._getCallRequestOptions(options);
        for (let i = 0; i < subPrompts.length; i += 1) {
            let data;
            if (params.stream) {
                data = await this._streamCompletionBatch({
                    ...params,
                    stream: true,
                    prompt: subPrompts[i],
                }, requestOptions, options?.signal, runManager, 
                // Index of this batch's first prompt within `prompts`, so
                // streamed-token callbacks report a global prompt index.
                i * this.batchSize);
            }
            else {
                data = await this.completionWithRetry({
                    ...params,
                    stream: false,
                    prompt: subPrompts[i],
                }, requestOptions);
            }
            choices.push(...data.choices);
            // Streamed responses are assembled without `usage`, so the counts
            // below are only accumulated for non-streaming calls.
            const { completion_tokens: completionTokens, prompt_tokens: promptTokens, total_tokens: totalTokens, } = data.usage ?? {};
            if (completionTokens) {
                tokenUsage.completionTokens =
                    (tokenUsage.completionTokens ?? 0) + completionTokens;
            }
            if (promptTokens) {
                tokenUsage.promptTokens = (tokenUsage.promptTokens ?? 0) + promptTokens;
            }
            if (totalTokens) {
                tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens;
            }
        }
        // Choices come back prompt-major (n consecutive choices per prompt).
        const generations = chunkArray(choices, this.n).map((promptChoices) => promptChoices.map((choice) => ({
            text: choice.text ?? "",
            generationInfo: {
                finishReason: choice.finish_reason,
                logprobs: choice.logprobs,
            },
        })));
        return {
            generations,
            llmOutput: { tokenUsage },
        };
    }
    /**
     * Streams one batch of prompts and merges the chunks into a single
     * completion-like object `{ id, object, created, model, choices }`. The
     * response metadata is taken from the first chunk. Each choice's text is
     * concatenated across chunks; `finish_reason` and `logprobs` come from the
     * last chunk seen for that choice index.
     *
     * @param request - Completions request body (must have `stream: true`).
     * @param requestOptions - Per-request client options (see _getCallRequestOptions).
     * @param signal - Abort signal checked once the stream is exhausted.
     * @param runManager - Optional callback manager; notified for every token.
     * @param promptOffset - Index of this batch's first prompt in the full prompt list.
     * @throws {Error} "AbortError" if `signal` was aborted.
     */
    async _streamCompletionBatch(request, requestOptions, signal, runManager, promptOffset) {
        const choices = [];
        let response;
        const stream = await this.completionWithRetry(request, requestOptions);
        for await (const message of stream) {
            // on the first message set the response properties
            if (!response) {
                response = {
                    id: message.id,
                    object: message.object,
                    created: message.created,
                    model: message.model,
                };
            }
            // on all messages, update choice
            for (const part of message.choices) {
                const existing = choices[part.index];
                if (!existing) {
                    choices[part.index] = part;
                }
                else {
                    existing.text += part.text;
                    existing.finish_reason = part.finish_reason;
                    existing.logprobs = part.logprobs;
                }
                // eslint-disable-next-line no-void
                void runManager?.handleLLMNewToken(part.text, {
                    prompt: promptOffset + Math.floor(part.index / this.n),
                    completion: part.index % this.n,
                });
            }
        }
        if (signal?.aborted) {
            throw new Error("AbortError");
        }
        return { ...response, choices };
    }
    /**
     * Streams the completion of a single prompt, yielding one GenerationChunk
     * per streamed chunk (first choice only).
     *
     * @throws {Error} "AbortError" if the call's signal was aborted.
     */
    // TODO(jacoblee): Refactor with _generate(..., {stream: true}) implementation?
    async *_streamResponseChunks(input, options, runManager) {
        const params = {
            ...this.invocationParams(options),
            prompt: input,
            stream: true,
        };
        const stream = await this.completionWithRetry(params, this._getCallRequestOptions(options));
        for await (const data of stream) {
            const choice = data?.choices?.[0];
            if (!choice) {
                continue;
            }
            const chunk = new GenerationChunk({
                text: choice.text,
                generationInfo: {
                    finishReason: choice.finish_reason,
                },
            });
            yield chunk;
            // eslint-disable-next-line no-void
            void runManager?.handleLLMNewToken(chunk.text ?? "");
        }
        if (options?.signal?.aborted) {
            throw new Error("AbortError");
        }
    }
    /**
     * Converts LangChain call options into OpenAI client request options: the
     * abort signal plus any raw request options supplied via the `options`
     * call key (for example `headers`). Streaming and non-streaming calls
     * share this helper so they receive the same per-request options.
     */
    _getCallRequestOptions(options) {
        return {
            signal: options?.signal,
            ...options?.options,
        };
    }
    /**
     * Calls the OpenAI API with retry logic in case of failures.
     *
     * Retries are delegated to `this.caller`; the underlying client is created
     * with `maxRetries: 0`. Client errors are passed through
     * `wrapOpenAIClientError` before being rethrown.
     *
     * @param request The request to send to the OpenAI API.
     * @param options Optional per-request client options.
     * @returns The response from the OpenAI API (a stream when `request.stream` is true).
     */
    async completionWithRetry(request, options) {
        const requestOptions = this._getClientOptions(options);
        return this.caller.call(async () => {
            try {
                const res = await this.client.completions.create(request, requestOptions);
                return res;
            }
            catch (e) {
                const error = wrapOpenAIClientError(e);
                throw error;
            }
        });
    }
    /**
     * Lazily creates the OpenAI client on first use and builds the options for
     * a single request.
     *
     * The base URL is resolved through `getEndpoint` (Azure-aware) and omitted
     * when empty. When an Azure API key is configured, the `api-key` header and
     * the `api-version` query parameter are added as defaults; values already
     * present in `options` take precedence.
     *
     * @param options Optional per-request client options.
     * @returns The merged request options (`clientConfig` overridden by `options`).
     */
    _getClientOptions(options) {
        if (!this.client) {
            const openAIEndpointConfig = {
                azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
                azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName,
                azureOpenAIApiKey: this.azureOpenAIApiKey,
                azureOpenAIBasePath: this.azureOpenAIBasePath,
                baseURL: this.clientConfig.baseURL,
            };
            const endpoint = getEndpoint(openAIEndpointConfig);
            const params = {
                ...this.clientConfig,
                baseURL: endpoint,
                timeout: this.timeout,
                // Retries are handled by this.caller in completionWithRetry().
                maxRetries: 0,
            };
            if (!params.baseURL) {
                delete params.baseURL;
            }
            this.client = new OpenAIClient(params);
        }
        const requestOptions = {
            ...this.clientConfig,
            ...options,
        };
        if (this.azureOpenAIApiKey) {
            requestOptions.headers = {
                "api-key": this.azureOpenAIApiKey,
                ...requestOptions.headers,
            };
            requestOptions.query = {
                "api-version": this.azureOpenAIApiVersion,
                ...requestOptions.query,
            };
        }
        return requestOptions;
    }
    _llmType() {
        return "openai";
    }
}