agsamantha/node_modules/@langchain/openai/dist/tools/dalle.js
2024-10-02 15:15:21 -05:00

189 lines
6.8 KiB
JavaScript

/* eslint-disable no-param-reassign */
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { OpenAI as OpenAIClient } from "openai";
import { Tool } from "@langchain/core/tools";
/**
 * A tool for generating images with OpenAI's DALL-E 2 or 3 API.
 *
 * The tool takes an image description as input. With `n === 1` it returns a
 * single string (an image URL or a base64 payload, depending on
 * `dallEResponseFormat`); with `n > 1` it returns a list of `image_url`
 * content objects.
 */
export class DallEAPIWrapper extends Tool {
    static lc_name() {
        return "DallEAPIWrapper";
    }
    /**
     * @param {object} [fields] Optional configuration. Keys read here:
     *   `apiKey` / `openAIApiKey`, `organization`, `baseUrl`, `model` /
     *   `modelName`, `style`, `quality`, `n`, `size`, `dallEResponseFormat`,
     *   `user`, and the legacy `responseFormat` (see shim below).
     */
    constructor(fields) {
        // Shim for new base tool param name: a legacy `responseFormat` of
        // "url" / "b64_json" is really the DALL-E response format. Work on a
        // shallow copy so the caller's options object is never mutated (and a
        // frozen options object does not throw).
        const usesLegacyResponseFormat = fields?.responseFormat !== undefined &&
            ["url", "b64_json"].includes(fields.responseFormat);
        const params = usesLegacyResponseFormat
            ? {
                ...fields,
                dallEResponseFormat: fields.responseFormat,
                responseFormat: "content",
            }
            : fields;
        super(params);
        // Own-property defaults, defined in a fixed order so that enumeration
        // order of the instance stays stable.
        const defaults = {
            name: "dalle_api_wrapper",
            description: "A wrapper around OpenAI DALL-E API. Useful for when you need to generate images from a text description. Input should be an image description.",
            client: undefined,
            model: "dall-e-3",
            style: "vivid",
            quality: "standard",
            n: 1,
            size: "1024x1024",
            dallEResponseFormat: "url",
            user: undefined,
        };
        for (const [key, value] of Object.entries(defaults)) {
            Object.defineProperty(this, key, {
                enumerable: true,
                configurable: true,
                writable: true,
                value,
            });
        }
        const openAIApiKey = params?.apiKey ??
            params?.openAIApiKey ??
            getEnvironmentVariable("OPENAI_API_KEY");
        const organization = params?.organization ?? getEnvironmentVariable("OPENAI_ORGANIZATION");
        const clientConfig = {
            apiKey: openAIApiKey,
            organization,
            dangerouslyAllowBrowser: true,
            // The OpenAI client option is spelled `baseURL`; the previous
            // `baseUrl` key was ignored, so custom endpoints never applied.
            baseURL: params?.baseUrl,
        };
        this.client = new OpenAIClient(clientConfig);
        this.model = params?.model ?? params?.modelName ?? this.model;
        this.style = params?.style ?? this.style;
        this.quality = params?.quality ?? this.quality;
        this.n = params?.n ?? this.n;
        this.size = params?.size ?? this.size;
        this.dallEResponseFormat =
            params?.dallEResponseFormat ?? this.dallEResponseFormat;
        this.user = params?.user;
    }
    /**
     * Processes the API responses when multiple images are generated.
     * Returns a list of MessageContentImageUrl objects. If the response
     * format is `url`, then the `image_url` field will contain the URL.
     * If it is `b64_json`, then the `image_url` field will contain an object
     * with a `url` field holding the base64 encoded image.
     * Items whose payload is missing, empty, or not a string are skipped.
     *
     * @param {OpenAIClient.Images.ImagesResponse[]} response The API responses
     * @returns {MessageContentImageUrl[]}
     */
    processMultipleGeneratedUrls(response) {
        const wantsUrl = this.dallEResponseFormat === "url";
        return response.flatMap((res) => (res.data ?? []).flatMap((item) => {
            const payload = wantsUrl ? item.url : item.b64_json;
            if (typeof payload !== "string" || payload === "") {
                return [];
            }
            return [
                {
                    type: "image_url",
                    image_url: wantsUrl ? payload : { url: payload },
                },
            ];
        }));
    }
    /**
     * Generates image(s) for the given prompt.
     *
     * @ignore
     * @param {string} input The image description used as the prompt.
     * @returns {Promise<string | MessageContentImageUrl[]>} For `n > 1`, the
     *   list built by `processMultipleGeneratedUrls`; otherwise the first
     *   available URL / base64 payload, or `""` when the response has none.
     */
    async _call(input) {
        const generateImageFields = {
            model: this.model,
            prompt: input,
            // Each request asks for exactly one image; multiple images are
            // obtained by issuing `this.n` parallel requests below
            // (NOTE(review): presumably because dall-e-3 only accepts n=1).
            n: 1,
            size: this.size,
            response_format: this.dallEResponseFormat,
            style: this.style,
            quality: this.quality,
            user: this.user,
        };
        if (this.n > 1) {
            const results = await Promise.all(Array.from({ length: this.n }, () => this.client.images.generate(generateImageFields)));
            return this.processMultipleGeneratedUrls(results);
        }
        const response = await this.client.images.generate(generateImageFields);
        const payloadKey = this.dallEResponseFormat === "url" ? "url" : "b64_json";
        // Pick the first item that actually carries a payload. The previous
        // code compared against the string "undefined", which filtered nothing.
        const firstPayload = (response.data ?? [])
            .map((item) => item[payloadKey])
            .find((value) => value != null);
        return firstPayload ?? "";
    }
}
// Attach the static `toolName` identifier to the class as an enumerable,
// configurable, writable own property.
Object.defineProperties(DallEAPIWrapper, {
    toolName: {
        value: "dalle_api_wrapper",
        writable: true,
        configurable: true,
        enumerable: true,
    },
});