import { OpenAI as OpenAIClient } from "openai";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Embeddings } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";
import { getEndpoint } from "./utils/azure.js";
import { wrapOpenAIClientError } from "./utils/openai.js";

/**
 * Instance fields of `OpenAIEmbeddings` paired with their initial values,
 * listed in the order in which they are defined on a new instance.
 */
const INITIAL_FIELD_VALUES = [
    ["modelName", "text-embedding-ada-002"],
    ["model", "text-embedding-ada-002"],
    ["batchSize", 512],
    // TODO: Update to `false` on next minor release (see: https://github.com/langchain-ai/langchainjs/pull/3612)
    ["stripNewLines", true],
    // The number of dimensions the resulting output embeddings should have.
    // Only supported in `text-embedding-3` and later models.
    ["dimensions", undefined],
    ["timeout", undefined],
    ["azureOpenAIApiVersion", undefined],
    ["azureOpenAIApiKey", undefined],
    ["azureADTokenProvider", undefined],
    ["azureOpenAIApiInstanceName", undefined],
    ["azureOpenAIApiDeploymentName", undefined],
    ["azureOpenAIBasePath", undefined],
    ["organization", undefined],
    ["client", undefined],
    ["clientConfig", undefined],
];

/**
 * Embeddings implementation backed by the OpenAI embeddings endpoint, with
 * optional Azure OpenAI settings (API key, instance/deployment name, API
 * version, base path). Extends the `Embeddings` base class.
 * @example
 * ```typescript
 * // Embed a query using OpenAIEmbeddings to generate embeddings for a given text
 * const model = new OpenAIEmbeddings();
 * const res = await model.embedQuery(
 *   "What would be a good company name for a company that makes colorful socks?",
 * );
 * console.log({ res });
 *
 * ```
 */
export class OpenAIEmbeddings extends Embeddings {
    /**
     * Resolves credentials and settings from `fields`, falling back to
     * environment variables, and prepares the client configuration. The
     * client itself is not created here; it is built on the first request.
     * @param fields Optional settings (API keys, model, batch size, Azure
     * options, ...). `maxConcurrency` defaults to 2 when not supplied.
     * @param configuration Optional client configuration; its `basePath` and
     * `baseOptions.headers` / `baseOptions.params` are mapped onto the client
     * options, and its remaining keys are spread into them.
     * @throws Error when no OpenAI key, Azure key or Azure AD token provider
     * is available, or when Azure credentials are present but the instance
     * name/base path, deployment name or API version is missing.
     */
    constructor(fields, configuration) {
        const opts = { maxConcurrency: 2, ...fields };
        super(opts);

        // Define every instance field up front with its initial value.
        for (const [fieldName, initialValue] of INITIAL_FIELD_VALUES) {
            Object.defineProperty(this, fieldName, {
                enumerable: true,
                configurable: true,
                writable: true,
                value: initialValue,
            });
        }

        // Credentials: explicit fields win over environment variables.
        const openAIKey = opts.apiKey ??
            opts.openAIApiKey ??
            getEnvironmentVariable("OPENAI_API_KEY");
        const azureKey = opts.azureOpenAIApiKey ??
            getEnvironmentVariable("AZURE_OPENAI_API_KEY");
        this.azureADTokenProvider = fields?.azureADTokenProvider ?? undefined;
        const hasAnyCredential = Boolean(azureKey || openAIKey || this.azureADTokenProvider);
        if (!hasAnyCredential) {
            throw new Error("OpenAI or Azure OpenAI API key or Token Provider not found");
        }

        // Azure settings. The embeddings-specific deployment name is
        // preferred over the generic one, both in fields and in env vars.
        const instanceName = opts.azureOpenAIApiInstanceName ??
            getEnvironmentVariable("AZURE_OPENAI_API_INSTANCE_NAME");
        const deploymentFromFields = opts.azureOpenAIApiEmbeddingsDeploymentName ||
            opts.azureOpenAIApiDeploymentName;
        const deploymentName = deploymentFromFields ??
            (getEnvironmentVariable("AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME") ||
                getEnvironmentVariable("AZURE_OPENAI_API_DEPLOYMENT_NAME"));
        const apiVersion = opts.azureOpenAIApiVersion ??
            getEnvironmentVariable("AZURE_OPENAI_API_VERSION");

        this.azureOpenAIBasePath = opts.azureOpenAIBasePath ??
            getEnvironmentVariable("AZURE_OPENAI_BASE_PATH");
        this.organization = opts.configuration?.organization ??
            getEnvironmentVariable("OPENAI_ORGANIZATION");

        // `model` takes precedence over `modelName`; both fields end up equal.
        this.modelName = opts.model ?? opts.modelName ?? this.model;
        this.model = this.modelName;
        // With an Azure key and no explicit batch size, send one text per request.
        this.batchSize = opts.batchSize ?? (azureKey ? 1 : this.batchSize);
        this.stripNewLines = opts.stripNewLines ?? this.stripNewLines;
        this.timeout = opts.timeout;
        this.dimensions = opts.dimensions;
        this.azureOpenAIApiVersion = apiVersion;
        this.azureOpenAIApiKey = azureKey;
        this.azureOpenAIApiInstanceName = instanceName;
        this.azureOpenAIApiDeploymentName = deploymentName;

        // Any Azure credential requires the full set of Azure settings.
        const usesAzure = Boolean(this.azureOpenAIApiKey || this.azureADTokenProvider);
        if (usesAzure) {
            if (!this.azureOpenAIApiInstanceName && !this.azureOpenAIBasePath) {
                throw new Error("Azure OpenAI API instance name not found");
            }
            if (!this.azureOpenAIApiDeploymentName) {
                throw new Error("Azure OpenAI API deployment name not found");
            }
            if (!this.azureOpenAIApiVersion) {
                throw new Error("Azure OpenAI API version not found");
            }
        }
        // On the Azure path a missing OpenAI key is replaced by an empty string.
        const clientApiKey = usesAzure ? openAIKey ?? "" : openAIKey;

        // Later spreads override the explicitly listed keys.
        this.clientConfig = {
            apiKey: clientApiKey,
            organization: this.organization,
            baseURL: configuration?.basePath,
            dangerouslyAllowBrowser: true,
            defaultHeaders: configuration?.baseOptions?.headers,
            defaultQuery: configuration?.baseOptions?.params,
            ...configuration,
            ...fields?.configuration,
        };
    }

    /**
     * Embeds a list of documents. The texts (with newlines replaced by
     * spaces when `stripNewLines` is set) are split into chunks of
     * `batchSize`, one request is issued per chunk, and the per-chunk
     * results are concatenated in chunk order.
     * @param texts Array of documents to generate embeddings for.
     * @returns Promise that resolves to a 2D array of embeddings, one per document.
     */
    async embedDocuments(texts) {
        const prepared = this.stripNewLines
            ? texts.map((t) => t.replace(/\n/g, " "))
            : texts;
        const batches = chunkArray(prepared, this.batchSize);

        // Start one request per batch, then wait for all of them.
        const pending = batches.map((batch) => this.embeddingWithRetry({
            model: this.model,
            input: batch,
            ...(this.dimensions ? { dimensions: this.dimensions } : {}),
        }));
        const batchResponses = await Promise.all(pending);

        // Take one embedding per input text, positionally within each batch.
        const embeddings = [];
        batchResponses.forEach(({ data }, batchIndex) => {
            batches[batchIndex].forEach((_text, itemIndex) => {
                embeddings.push(data[itemIndex].embedding);
            });
        });
        return embeddings;
    }

    /**
     * Embeds a single piece of text with one request and returns the first
     * embedding of the response.
     * @param text Document to generate an embedding for.
     * @returns Promise that resolves to an embedding for the document.
     */
    async embedQuery(text) {
        const input = this.stripNewLines ? text.replace(/\n/g, " ") : text;
        const { data } = await this.embeddingWithRetry({
            model: this.model,
            input,
            ...(this.dimensions ? { dimensions: this.dimensions } : {}),
        });
        return data[0].embedding;
    }

    /**
     * Sends an embeddings request, creating and caching the API client on
     * first use. The call is dispatched through `this.caller` (inherited;
     * presumably the place where retries and concurrency limits are applied
     * — not visible in this file), and any failure is passed through
     * `wrapOpenAIClientError` before being rethrown.
     * @param request Request to send to the OpenAI API.
     * @returns Promise that resolves to the response from the API.
     */
    async embeddingWithRetry(request) {
        if (!this.client) {
            const endpoint = getEndpoint({
                azureOpenAIApiDeploymentName: this.azureOpenAIApiDeploymentName,
                azureOpenAIApiInstanceName: this.azureOpenAIApiInstanceName,
                azureOpenAIApiKey: this.azureOpenAIApiKey,
                azureOpenAIBasePath: this.azureOpenAIBasePath,
                baseURL: this.clientConfig.baseURL,
            });
            const clientOptions = {
                ...this.clientConfig,
                baseURL: endpoint,
                timeout: this.timeout,
                // The client's own retries are turned off; the request goes through `this.caller` below.
                maxRetries: 0,
            };
            // Omit an empty base URL rather than passing it to the client.
            if (!clientOptions.baseURL) {
                delete clientOptions.baseURL;
            }
            this.client = new OpenAIClient(clientOptions);
        }

        // With an Azure key, the key goes in a header and the API version in the query string.
        const requestOptions = this.azureOpenAIApiKey
            ? {
                headers: { "api-key": this.azureOpenAIApiKey },
                query: { "api-version": this.azureOpenAIApiVersion },
            }
            : {};

        return this.caller.call(async () => {
            try {
                return await this.client.embeddings.create(request, requestOptions);
            }
            catch (err) {
                throw wrapOpenAIClientError(err);
            }
        });
    }
}