189 lines
6.8 KiB
JavaScript
189 lines
6.8 KiB
JavaScript
/* eslint-disable no-param-reassign */
|
|
import { getEnvironmentVariable } from "@langchain/core/utils/env";
|
|
import { OpenAI as OpenAIClient } from "openai";
|
|
import { Tool } from "@langchain/core/tools";
|
|
/**
 * A tool for generating images with OpenAI's DALL-E 2 or 3 API.
 */
|
|
export class DallEAPIWrapper extends Tool {
    /**
     * Identifier string for this class.
     *
     * @returns {string} Always `"DallEAPIWrapper"`.
     */
    static lc_name() {
        return "DallEAPIWrapper";
    }

    /**
     * Builds the tool and the OpenAI client it uses.
     *
     * @param {object} [fields] Optional configuration. The object passed in
     *   is never modified.
     * @param {string} [fields.apiKey] API key; falls back to
     *   `fields.openAIApiKey`, then the `OPENAI_API_KEY` environment variable.
     * @param {string} [fields.openAIApiKey] Alternative name for the API key.
     * @param {string} [fields.organization] Organization; falls back to the
     *   `OPENAI_ORGANIZATION` environment variable.
     * @param {string} [fields.baseUrl] Custom API endpoint, forwarded to the
     *   OpenAI client as its `baseURL` option.
     * @param {string} [fields.model] Model name (default `"dall-e-3"`).
     * @param {string} [fields.modelName] Alternative name for `model`.
     * @param {string} [fields.style] Image style (default `"vivid"`).
     * @param {string} [fields.quality] Image quality (default `"standard"`).
     * @param {number} [fields.n] Number of images to generate (default `1`).
     * @param {string} [fields.size] Image size (default `"1024x1024"`).
     * @param {string} [fields.dallEResponseFormat] `"url"` (default) or
     *   `"b64_json"`.
     * @param {string} [fields.responseFormat] When `"url"` or `"b64_json"`,
     *   treated as `dallEResponseFormat` (see the shim below).
     * @param {string} [fields.user] End-user identifier sent with each request.
     */
    constructor(fields) {
        // Shim for new base tool param name: a `responseFormat` of "url" or
        // "b64_json" is really the DALL-E response format, so it is moved to
        // `dallEResponseFormat` and the base class receives "content" instead.
        // The shim works on a shallow copy so the caller's object is untouched.
        const needsShim = fields?.responseFormat !== undefined &&
            ["url", "b64_json"].includes(fields.responseFormat);
        const params = needsShim
            ? {
                ...fields,
                dallEResponseFormat: fields.responseFormat,
                responseFormat: "content",
            }
            : fields;
        super(params);

        // Default values for the instance properties, in definition order.
        const defaults = {
            name: "dalle_api_wrapper",
            description: "A wrapper around OpenAI DALL-E API. Useful for when you need to generate images from a text description. Input should be an image description.",
            client: undefined,
            model: "dall-e-3",
            style: "vivid",
            quality: "standard",
            n: 1,
            size: "1024x1024",
            dallEResponseFormat: "url",
            user: undefined,
        };
        // Each one is defined as an own, enumerable, configurable, writable
        // data property (define semantics rather than plain assignment).
        for (const [key, value] of Object.entries(defaults)) {
            Object.defineProperty(this, key, {
                enumerable: true,
                configurable: true,
                writable: true,
                value,
            });
        }

        const openAIApiKey = params?.apiKey ??
            params?.openAIApiKey ??
            getEnvironmentVariable("OPENAI_API_KEY");
        const organization = params?.organization ??
            getEnvironmentVariable("OPENAI_ORGANIZATION");
        this.client = new OpenAIClient({
            apiKey: openAIApiKey,
            organization,
            dangerouslyAllowBrowser: true,
            // The OpenAI client option is spelled `baseURL`; a `baseUrl` key
            // would be ignored and the custom endpoint never used.
            baseURL: params?.baseUrl,
        });

        this.model = params?.model ?? params?.modelName ?? this.model;
        this.style = params?.style ?? this.style;
        this.quality = params?.quality ?? this.quality;
        this.n = params?.n ?? this.n;
        this.size = params?.size ?? this.size;
        this.dallEResponseFormat =
            params?.dallEResponseFormat ?? this.dallEResponseFormat;
        this.user = params?.user;
    }

    /**
     * Processes the API responses when multiple images are generated.
     * Returns a flat list of `image_url` content objects. If the response
     * format is `url`, the `image_url` field is the URL string. Otherwise
     * (`b64_json`) the `image_url` field is an object whose `url` field holds
     * the base64 payload exactly as returned by the API.
     * Entries whose payload is missing, empty or not a string are skipped,
     * as are responses that carry no `data` array.
     *
     * @param {OpenAIClient.Images.ImagesResponse[]} response The API responses
     * @returns {MessageContentImageUrl[]}
     */
    processMultipleGeneratedUrls(response) {
        const wantsUrl = this.dallEResponseFormat === "url";
        return response.flatMap((res) => (res.data ?? []).flatMap((item) => {
            const payload = wantsUrl ? item.url : item.b64_json;
            if (typeof payload !== "string" || payload === "") {
                return [];
            }
            return [
                {
                    type: "image_url",
                    image_url: wantsUrl ? payload : { url: payload },
                },
            ];
        }));
    }

    /**
     * Generates image(s) for the given prompt.
     *
     * Every request asks the API for a single image (`n: 1`); when `this.n`
     * is greater than 1, `this.n` such requests are issued in parallel and
     * their results are combined by `processMultipleGeneratedUrls`.
     *
     * @ignore
     * @param {string} input The image description used as the prompt.
     * @returns {Promise<string | MessageContentImageUrl[]>} For a single
     *   image: the URL or base64 string of the first usable entry, or `""`
     *   when the response has none. For multiple images: a list of
     *   `image_url` content objects.
     */
    async _call(input) {
        const generateImageFields = {
            model: this.model,
            prompt: input,
            n: 1,
            size: this.size,
            response_format: this.dallEResponseFormat,
            style: this.style,
            quality: this.quality,
            user: this.user,
        };
        if (this.n > 1) {
            const results = await Promise.all(Array.from({ length: this.n }, () => this.client.images.generate(generateImageFields)));
            return this.processMultipleGeneratedUrls(results);
        }
        const response = await this.client.images.generate(generateImageFields);
        const payloadKey = this.dallEResponseFormat === "url" ? "url" : "b64_json";
        // Skip entries without a payload (compare against the value
        // `undefined`/`null`, not the string "undefined") and fall back to ""
        // when nothing usable was returned.
        const [data = ""] = (response.data ?? [])
            .map((item) => item[payloadKey])
            .filter((value) => value !== undefined && value !== null);
        return data;
    }
}
|
|
// Static `toolName` on the class, defined as an own, enumerable,
// configurable, writable data property; its value matches the instance
// `name` set in the constructor.
Object.defineProperty(DallEAPIWrapper, "toolName", {
    enumerable: true,
    configurable: true,
    writable: true,
    value: "dalle_api_wrapper"
});
|