agsamantha/node_modules/@langchain/openai/dist/legacy.js

455 lines
15 KiB
JavaScript
Raw Normal View History

2024-10-02 15:15:21 -05:00
import { OpenAI as OpenAIClient } from "openai";
import { GenerationChunk } from "@langchain/core/outputs";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { LLM } from "@langchain/core/language_models/llms";
import { getEndpoint } from "./utils/azure.js";
import { wrapOpenAIClientError } from "./utils/openai.js";
/**
 * @deprecated For legacy compatibility. Use ChatOpenAI instead.
 *
 * Wrapper around OpenAI large language models that use the Chat endpoint.
 * Each prompt is sent as a single user message (optionally preceded by
 * `prefixMessages`) and the text content of the reply is returned.
 *
 * To use you should have the `openai` package installed, with the
 * `OPENAI_API_KEY` environment variable set.
 *
 * To use with Azure you should have the `openai` package installed, with the
 * `AZURE_OPENAI_API_KEY`,
 * `AZURE_OPENAI_API_INSTANCE_NAME` (or `AZURE_OPENAI_BASE_PATH`),
 * `AZURE_OPENAI_API_DEPLOYMENT_NAME` (or
 * `AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME`)
 * and `AZURE_OPENAI_API_VERSION` environment variables set.
 *
 * @remarks
 * Any parameters that are valid to be passed to {@link
 * https://platform.openai.com/docs/api-reference/chat/create |
 * `openai.chat.completions.create`} can be passed through {@link modelKwargs},
 * even if not explicitly available on this class.
 *
 * @augments LLM
 * @augments OpenAIInput
 * @augments AzureOpenAIChatInput
 * @example
 * ```typescript
 * const model = new OpenAIChat({
 * prefixMessages: [
 * {
 * role: "system",
 * content: "You are a helpful assistant that answers in pirate language",
 * },
 * ],
 * maxTokens: 50,
 * });
 *
 * const res = await model.invoke(
 * "What would be a good company name for a company that makes colorful socks?"
 * );
 * console.log({ res });
 * ```
 */
export class OpenAIChat extends LLM {
    static lc_name() {
        return "OpenAIChat";
    }
    /** Call-option keys accepted by this model: the parent's keys plus `options` and `promptIndex`. */
    get callKeys() {
        return [...super.callKeys, "options", "promptIndex"];
    }
    /** Maps secret-bearing fields to the environment variables they are read from. */
    get lc_secrets() {
        return {
            openAIApiKey: "OPENAI_API_KEY",
            azureOpenAIApiKey: "AZURE_OPENAI_API_KEY",
            organization: "OPENAI_ORGANIZATION",
        };
    }
    /** Maps field names to their alias names (presumably used by LangChain serialization). */
    get lc_aliases() {
        return {
            modelName: "model",
            openAIApiKey: "openai_api_key",
            azureOpenAIApiVersion: "azure_openai_api_version",
            azureOpenAIApiKey: "azure_openai_api_key",
            azureOpenAIApiInstanceName: "azure_openai_api_instance_name",
            azureOpenAIApiDeploymentName: "azure_openai_api_deployment_name",
        };
    }
    /**
     * @param fields Model and client options. May be omitted, in which case
     *   credentials and Azure settings come from environment variables.
     * @param configuration Deprecated extra OpenAI client options; entries in
     *   `fields.configuration` take precedence over it.
     * @throws {Error} If neither an OpenAI nor an Azure OpenAI API key is found.
     * @throws {Error} If `n` > 1 (this LLM returns a single completion).
     * @throws {Error} If an Azure key is present but the instance name / base
     *   path, the deployment name, or the API version is missing.
     */
    constructor(fields,
    /** @deprecated */
    configuration) {
        super(fields ?? {});
        Object.defineProperty(this, "lc_serializable", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: true
        });
        Object.defineProperty(this, "temperature", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 1
        });
        Object.defineProperty(this, "topP", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 1
        });
        Object.defineProperty(this, "frequencyPenalty", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 0
        });
        Object.defineProperty(this, "presencePenalty", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 0
        });
        Object.defineProperty(this, "n", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 1
        });
        Object.defineProperty(this, "logitBias", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "maxTokens", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "modelName", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "gpt-3.5-turbo"
        });
        Object.defineProperty(this, "model", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "gpt-3.5-turbo"
        });
        Object.defineProperty(this, "prefixMessages", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "modelKwargs", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "timeout", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "stop", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "user", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "streaming", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: false
        });
        Object.defineProperty(this, "openAIApiKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "azureOpenAIApiVersion", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "azureOpenAIApiKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "azureOpenAIApiInstanceName", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "azureOpenAIApiDeploymentName", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "azureOpenAIBasePath", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "organization", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "client", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "clientConfig", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        this.openAIApiKey =
            fields?.apiKey ??
                fields?.openAIApiKey ??
                getEnvironmentVariable("OPENAI_API_KEY");
        this.azureOpenAIApiKey =
            fields?.azureOpenAIApiKey ??
                getEnvironmentVariable("AZURE_OPENAI_API_KEY");
        if (!this.azureOpenAIApiKey && !this.openAIApiKey) {
            throw new Error("OpenAI or Azure OpenAI API key not found");
        }
        this.azureOpenAIApiInstanceName =
            fields?.azureOpenAIApiInstanceName ??
                getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME");
        // NOTE(review): `||` treats empty strings as missing inside each group,
        // but `??` only falls back to the env vars when the field-level result
        // is null/undefined (an explicit "" field does not fall back).
        this.azureOpenAIApiDeploymentName =
            (fields?.azureOpenAIApiCompletionsDeploymentName ||
                fields?.azureOpenAIApiDeploymentName) ??
                (getEnvironmentVariable("AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME") ||
                    getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME"));
        this.azureOpenAIApiVersion =
            fields?.azureOpenAIApiVersion ??
                getEnvironmentVariable("AZURE_OPENAI_API_VERSION");
        this.azureOpenAIBasePath =
            fields?.azureOpenAIBasePath ??
                getEnvironmentVariable("AZURE_OPENAI_BASE_PATH");
        this.organization =
            fields?.configuration?.organization ??
                getEnvironmentVariable("OPENAI_ORGANIZATION");
        // Keep `model` and its legacy alias `modelName` in sync; previously only
        // `modelName` was updated and `model` always reported the default.
        this.model = fields?.model ?? fields?.modelName ?? this.model;
        this.modelName = this.model;
        this.prefixMessages = fields?.prefixMessages ?? this.prefixMessages;
        this.modelKwargs = fields?.modelKwargs ?? {};
        this.timeout = fields?.timeout;
        this.temperature = fields?.temperature ?? this.temperature;
        this.topP = fields?.topP ?? this.topP;
        this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty;
        this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty;
        this.n = fields?.n ?? this.n;
        this.logitBias = fields?.logitBias;
        this.maxTokens = fields?.maxTokens;
        this.stop = fields?.stop;
        this.user = fields?.user;
        this.streaming = fields?.streaming ?? false;
        if (this.n > 1) {
            throw new Error("Cannot use n > 1 in OpenAIChat LLM. Use ChatOpenAI Chat Model instead.");
        }
        if (this.azureOpenAIApiKey) {
            if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) {
                throw new Error("Azure OpenAI API instance name not found");
            }
            if (!this.azureOpenAIApiDeploymentName) {
                throw new Error("Azure OpenAI API deployment name not found");
            }
            if (!this.azureOpenAIApiVersion) {
                throw new Error("Azure OpenAI API version not found");
            }
            // With Azure auth the OpenAI key may be absent; use "" instead of undefined.
            this.openAIApiKey = this.openAIApiKey ?? "";
        }
        // Explicit keys first; the raw `configuration` objects are spread last so
        // their entries win, with `fields.configuration` overriding the deprecated arg.
        this.clientConfig = {
            apiKey: this.openAIApiKey,
            organization: this.organization,
            baseURL: configuration?.basePath ?? fields?.configuration?.basePath,
            dangerouslyAllowBrowser: true,
            defaultHeaders: configuration?.baseOptions?.headers ??
                fields?.configuration?.baseOptions?.headers,
            defaultQuery: configuration?.baseOptions?.params ??
                fields?.configuration?.baseOptions?.params,
            ...configuration,
            ...fields?.configuration,
        };
    }
    /**
     * Get the parameters used to invoke the model.
     *
     * A `maxTokens` of -1 is sent as `undefined`; a call-level `stop` overrides
     * the instance `stop`; `modelKwargs` are spread last and override everything.
     */
    invocationParams(options) {
        return {
            model: this.modelName,
            temperature: this.temperature,
            top_p: this.topP,
            frequency_penalty: this.frequencyPenalty,
            presence_penalty: this.presencePenalty,
            n: this.n,
            logit_bias: this.logitBias,
            max_tokens: this.maxTokens === -1 ? undefined : this.maxTokens,
            stop: options?.stop ?? this.stop,
            user: this.user,
            stream: this.streaming,
            ...this.modelKwargs,
        };
    }
    /**
     * @ignore
     * NOTE(review): `clientConfig` contains `apiKey`, so the API key is part of
     * these params; confirm that is acceptable wherever they are logged/keyed.
     */
    _identifyingParams() {
        return {
            model_name: this.modelName,
            ...this.invocationParams(),
            ...this.clientConfig,
        };
    }
    /**
     * Get the identifying parameters for the model.
     */
    identifyingParams() {
        return this._identifyingParams();
    }
    /**
     * Formats the messages for the OpenAI API.
     * @param prompt The prompt to be formatted.
     * @returns Array of formatted messages: `prefixMessages` (if any) followed
     *   by the prompt as a single user message.
     */
    formatMessages(prompt) {
        const message = {
            role: "user",
            content: prompt,
        };
        return this.prefixMessages ? [...this.prefixMessages, message] : [message];
    }
    /**
     * Streams the completion for `prompt`, yielding one chunk per streamed choice.
     *
     * @param prompt The prompt text, sent as a user message.
     * @param options Call options. `signal` and `options` (per-request client
     *   options) are forwarded to the OpenAI client; `promptIndex` is reported
     *   to the run manager with each token.
     * @param runManager Optional callback manager notified of each new token.
     * @throws {Error} "AbortError" if `options.signal` is aborted once the
     *   stream has been consumed.
     */
    async *_streamResponseChunks(prompt, options, runManager) {
        const params = {
            ...this.invocationParams(options),
            messages: this.formatMessages(prompt),
            stream: true,
        };
        // Forward only the abort signal and the caller's per-request client
        // options, exactly like the non-streaming path in `_call`, instead of
        // passing every LangChain call option (callbacks, tags, ...) through.
        const stream = await this.completionWithRetry(params, {
            signal: options.signal,
            ...options.options,
        });
        for await (const data of stream) {
            // Tolerate chunks without a `choices` array and choices without a
            // `delta`; indexing them directly would throw a TypeError.
            const choice = data?.choices?.[0];
            if (!choice) {
                continue;
            }
            const generationChunk = new GenerationChunk({
                text: choice.delta?.content ?? "",
            });
            yield generationChunk;
            const newTokenIndices = {
                prompt: options.promptIndex ?? 0,
                completion: choice.index ?? 0,
            };
            // eslint-disable-next-line no-void
            void runManager?.handleLLMNewToken(generationChunk.text ?? "", newTokenIndices);
        }
        if (options.signal?.aborted) {
            throw new Error("AbortError");
        }
    }
    /**
     * @ignore
     * Returns the completion text for `prompt`. When streaming is enabled the
     * streamed chunks are concatenated; otherwise a single request is made.
     * Returns "" when the response carries no content.
     */
    async _call(prompt, options, runManager) {
        const params = this.invocationParams(options);
        if (params.stream) {
            let finalChunk;
            for await (const chunk of this._streamResponseChunks(prompt, options, runManager)) {
                finalChunk = finalChunk === undefined ? chunk : finalChunk.concat(chunk);
            }
            return finalChunk?.text ?? "";
        }
        const response = await this.completionWithRetry({
            ...params,
            stream: false,
            messages: this.formatMessages(prompt),
        }, {
            signal: options.signal,
            ...options.options,
        });
        return response?.choices?.[0]?.message?.content ?? "";
    }
    /**
     * Calls `client.chat.completions.create` through `this.caller`, converting
     * client errors with `wrapOpenAIClientError`.
     *
     * @param request The chat completion request body.
     * @param options Per-request client options, merged over `clientConfig`.
     */
    async completionWithRetry(request, options) {
        const requestOptions = this._getClientOptions(options);
        return this.caller.call(async () => {
            try {
                return await this.client.chat.completions.create(request, requestOptions);
            }
            catch (e) {
                throw wrapOpenAIClientError(e);
            }
        });
    }
    /**
     * @ignore
     * Lazily creates the OpenAI client (with `maxRetries: 0`, since retries go
     * through `this.caller`) and returns the per-request options: `clientConfig`
     * overlaid with `options`, plus Azure `api-key` header and `api-version`
     * query defaults when an Azure key is configured.
     */
    _getClientOptions(options) {
        if (!this.client) {
            const openAIEndpointConfig = {
                azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
                azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName,
                azureOpenAIApiKey: this.azureOpenAIApiKey,
                azureOpenAIBasePath: this.azureOpenAIBasePath,
                baseURL: this.clientConfig.baseURL,
            };
            const endpoint = getEndpoint(openAIEndpointConfig);
            const params = {
                ...this.clientConfig,
                baseURL: endpoint,
                timeout: this.timeout,
                maxRetries: 0,
            };
            // Drop an empty baseURL entirely (presumably so the client's own
            // default endpoint is used — confirm against the openai SDK).
            if (!params.baseURL) {
                delete params.baseURL;
            }
            this.client = new OpenAIClient(params);
        }
        const requestOptions = {
            ...this.clientConfig,
            ...options,
        };
        if (this.azureOpenAIApiKey) {
            // Caller-supplied headers/query entries take precedence over these defaults.
            requestOptions.headers = {
                "api-key": this.azureOpenAIApiKey,
                ...requestOptions.headers,
            };
            requestOptions.query = {
                "api-version": this.azureOpenAIApiVersion,
                ...requestOptions.query,
            };
        }
        return requestOptions;
    }
    _llmType() {
        return "openai";
    }
}