agsamantha/node_modules/langchain/dist/experimental/openai_assistant/index.js
2024-10-02 15:15:21 -05:00

260 lines
9.5 KiB
JavaScript

import { OpenAIClient } from "@langchain/openai";
import { StructuredTool } from "@langchain/core/tools";
import { Runnable } from "@langchain/core/runnables";
import { formatToOpenAIAssistantTool } from "@langchain/openai";
import { sleep } from "../../util/time.js";
/**
 * Runnable wrapper around an OpenAI Assistant (the client's `beta` Assistants
 * API).
 *
 * `invoke` either starts a new thread + run, adds a message to an existing
 * thread and starts a run, or submits tool outputs to an existing run. It then
 * polls the run until it leaves the "queued"/"in_progress" states and returns
 * either raw API objects or, when `asAgent` is set, agent action / finish
 * objects for use with an AgentExecutor.
 */
export class OpenAIAssistantRunnable extends Runnable {
  /**
   * @param {object} fields
   * @param {OpenAIClient} [fields.client] - Pre-built client. When omitted, a new
   *   `OpenAIClient` is created from `fields.clientOptions`.
   * @param {object} [fields.clientOptions] - Options for the created client.
   * @param {string} fields.assistantId - ID of the assistant to run.
   * @param {boolean} [fields.asAgent] - Return agent actions / finish objects
   *   instead of raw API payloads.
   * @param {number} [fields.pollIntervalMs=1000] - Delay between run status polls.
   */
  constructor(fields) {
    super(fields);
    Object.defineProperty(this, "lc_namespace", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: ["langchain", "experimental", "openai_assistant"]
    });
    Object.defineProperty(this, "client", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0
    });
    Object.defineProperty(this, "assistantId", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0
    });
    Object.defineProperty(this, "pollIntervalMs", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: 1000
    });
    Object.defineProperty(this, "asAgent", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: void 0
    });
    this.client = fields.client ?? new OpenAIClient(fields?.clientOptions);
    this.assistantId = fields.assistantId;
    // Honor a caller-supplied interval (createAssistant forwards one); keep the
    // 1000 ms default otherwise. `??` so an explicit 0 is respected.
    this.pollIntervalMs = fields.pollIntervalMs ?? this.pollIntervalMs;
    this.asAgent = fields.asAgent ?? this.asAgent;
  }

  /**
   * Creates an assistant through the API and returns a runnable bound to it.
   * `StructuredTool` instances are converted with `formatToOpenAIAssistantTool`;
   * any other tool definition is passed through unchanged.
   *
   * @returns {Promise<OpenAIAssistantRunnable>}
   */
  static async createAssistant({ model, name, instructions, tools, client, clientOptions, asAgent, pollIntervalMs, fileIds, }) {
    const formattedTools = tools?.map((tool) => {
      // eslint-disable-next-line no-instanceof/no-instanceof
      if (tool instanceof StructuredTool) {
        return formatToOpenAIAssistantTool(tool);
      }
      return tool;
    }) ?? [];
    const oaiClient = client ?? new OpenAIClient(clientOptions);
    const assistant = await oaiClient.beta.assistants.create({
      name,
      instructions,
      tools: formattedTools,
      model,
      file_ids: fileIds,
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
    });
    return new this({
      client: oaiClient,
      assistantId: assistant.id,
      asAgent,
      pollIntervalMs,
    });
  }

  /**
   * Runs the assistant. The branch is chosen from the shape of `input`:
   * - `asAgent` and a non-empty `input.steps`: submit the steps' observations
   *   as tool outputs to the run recorded in the last step.
   * - no `threadId`: create a new thread (with `input.content`) and run it.
   * - `threadId` but no `runId`: append a user message and start a new run.
   * - `threadId` and `runId`: submit `input.toolOutputs` to that run.
   *
   * @returns The value produced by `_getResponse` for the resulting run.
   */
  async invoke(input, _options) {
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    let run;
    if (this.asAgent && input.steps && input.steps.length > 0) {
      const parsedStepsInput = await this._parseStepsInput(input);
      if (parsedStepsInput === input) {
        // _parseStepsInput returns the raw input when the last run is not
        // waiting for tool outputs (e.g. it already completed or failed).
        // Submitting would target an undefined run, so report that run's
        // outcome instead.
        const { runId, threadId } = input.steps[input.steps.length - 1].action;
        return this._getResponse(runId, threadId);
      }
      run = await this.client.beta.threads.runs.submitToolOutputs(parsedStepsInput.threadId, parsedStepsInput.runId, {
        tool_outputs: parsedStepsInput.toolOutputs,
      });
    }
    else if (!("threadId" in input)) {
      const thread = {
        messages: [
          {
            role: "user",
            content: input.content,
            file_ids: input.fileIds,
            metadata: input.messagesMetadata,
          },
        ],
        metadata: input.threadMetadata,
      };
      run = await this._createThreadAndRun({
        ...input,
        thread,
      });
    }
    else if (!("runId" in input)) {
      await this.client.beta.threads.messages.create(input.threadId, {
        content: input.content,
        role: "user",
        // Accept the camelCase `fileIds` used by the new-thread branch while
        // still honoring the legacy snake_case `file_ids` key.
        file_ids: input.fileIds ?? input.file_ids,
        metadata: input.messagesMetadata,
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
      });
      run = await this._createRun(input);
    }
    else {
      // Submitting tool outputs to an existing run, outside the AgentExecutor
      // framework.
      run = await this.client.beta.threads.runs.submitToolOutputs(input.threadId, input.runId, {
        tool_outputs: input.toolOutputs,
      });
    }
    return this._getResponse(run.id, run.thread_id);
  }

  /**
   * Delete an assistant.
   *
   * @link {https://platform.openai.com/docs/api-reference/assistants/deleteAssistant}
   * @returns {Promise<AssistantDeleted>}
   */
  async deleteAssistant() {
    return await this.client.beta.assistants.del(this.assistantId);
  }

  /**
   * Retrieves an assistant.
   *
   * @link {https://platform.openai.com/docs/api-reference/assistants/getAssistant}
   * @returns {Promise<OpenAIClient.Beta.Assistants.Assistant>}
   */
  async getAssistant() {
    return await this.client.beta.assistants.retrieve(this.assistantId);
  }

  /**
   * Modifies an assistant.
   *
   * @link {https://platform.openai.com/docs/api-reference/assistants/modifyAssistant}
   * @returns {Promise<OpenAIClient.Beta.Assistants.Assistant>}
   */
  async modifyAssistant({ model, name, instructions, fileIds, }) {
    return await this.client.beta.assistants.update(this.assistantId, {
      name,
      instructions,
      model,
      file_ids: fileIds,
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
    });
  }

  /**
   * Builds the tool-output submission for the run recorded in the last step.
   * Waits for that run to settle, then pairs each pending tool call with the
   * step whose `action.toolCallId` matches it; tool calls without a matching
   * step are skipped.
   *
   * @returns `{ toolOutputs, runId, threadId }`, or `input` itself (same
   *   object) when the run has no pending tool calls.
   */
  async _parseStepsInput(input) {
    const { action: { runId, threadId }, } = input.steps[input.steps.length - 1];
    const run = await this._waitForRun(runId, threadId);
    const toolCalls = run.required_action?.submit_tool_outputs.tool_calls;
    if (!toolCalls) {
      return input;
    }
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    const toolOutputs = toolCalls.flatMap((toolCall) => {
      const matchedAction = input.steps.find((step) => step.action.toolCallId === toolCall.id);
      return matchedAction
        ? [
          {
            output: matchedAction.observation,
            tool_call_id: matchedAction.action.toolCallId,
          },
        ]
        : [];
    });
    return { toolOutputs, runId, threadId };
  }

  /**
   * Starts a run of this assistant on an existing thread. `metadata` here is
   * attached to the run.
   */
  async _createRun({ instructions, model, tools, metadata, threadId, }) {
    const run = this.client.beta.threads.runs.create(threadId, {
      assistant_id: this.assistantId,
      instructions,
      model,
      tools,
      metadata,
    });
    return run;
  }

  /**
   * Creates a thread from `input.thread` and starts a run of this assistant on
   * it in one request. Only run options present on `input` are forwarded.
   */
  async _createThreadAndRun(input) {
    const params = {};
    for (const key of ["instructions", "model", "tools"]) {
      if (key in input) {
        params[key] = input[key];
      }
    }
    // Callers pass run metadata as `run_metadata` (keeping it distinct from
    // `threadMetadata` / `messagesMetadata`), but the createAndRun request
    // names the run's metadata field `metadata`.
    if ("run_metadata" in input) {
      params.metadata = input.run_metadata;
    }
    return this.client.beta.threads.createAndRun({
      ...params,
      thread: input.thread,
      assistant_id: this.assistantId,
    });
  }

  /**
   * Polls the run every `pollIntervalMs` until its status is neither
   * "in_progress" nor "queued", then returns the last retrieved run. There is
   * no overall timeout.
   */
  async _waitForRun(runId, threadId) {
    let inProgress = true;
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    let run = {};
    while (inProgress) {
      run = await this.client.beta.threads.runs.retrieve(threadId, runId);
      inProgress = ["in_progress", "queued"].includes(run.status);
      if (inProgress) {
        await sleep(this.pollIntervalMs);
      }
    }
    return run;
  }

  /**
   * Waits for the run and converts its final state into a result:
   * - "completed": the thread messages produced by this run (newest first); with
   *   `asAgent`, a finish object whose `returnValues.output` joins their text
   *   parts with "\n" (only when every content part is text).
   * - "requires_action": the pending tool calls; with `asAgent`, one action per
   *   call with its JSON-parsed arguments.
   *
   * @throws {Error} For any other status. With `asAgent`, this also happens for
   *   a completed run whose messages contain non-text content.
   */
  async _getResponse(runId, threadId) {
    const run = await this._waitForRun(runId, threadId);
    if (run.status === "completed") {
      const messages = await this.client.beta.threads.messages.list(threadId, {
        order: "desc",
      });
      const newMessages = messages.data.filter((msg) => msg.run_id === runId);
      if (!this.asAgent) {
        return newMessages;
      }
      const answer = newMessages.flatMap((msg) => msg.content);
      if (answer.every((item) => item.type === "text")) {
        const answerString = answer
          .map((item) => item.type === "text" && item.text.value)
          .join("\n");
        return {
          returnValues: {
            output: answerString,
            runId,
            threadId,
          },
          log: "",
          runId,
          threadId,
        };
      }
    }
    else if (run.status === "requires_action") {
      if (!this.asAgent) {
        return run.required_action?.submit_tool_outputs.tool_calls ?? [];
      }
      const actions = [];
      run.required_action?.submit_tool_outputs.tool_calls.forEach(
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (item) => {
        const functionCall = item.function;
        const args = JSON.parse(functionCall.arguments);
        actions.push({
          tool: functionCall.name,
          toolInput: args,
          toolCallId: item.id,
          log: "",
          runId,
          threadId,
        });
      });
      return actions;
    }
    const runInfo = JSON.stringify(run, null, 2);
    throw new Error(`Unexpected run status ${run.status}.\nFull run info:\n\n${runInfo}`);
  }
}