agsamantha/node_modules/langchain/dist/agents/toolkits/sql/sql.js

73 lines
2.5 KiB
JavaScript
Raw Normal View History

2024-10-02 20:15:21 +00:00
import { BaseToolkit } from "@langchain/core/tools";
import { renderTemplate } from "@langchain/core/prompts";
import { InfoSqlTool, ListTablesSqlTool, QueryCheckerTool, QuerySqlTool, } from "../../../tools/sql.js";
import { SQL_PREFIX, SQL_SUFFIX } from "./prompt.js";
import { LLMChain } from "../../../chains/llm_chain.js";
import { ZeroShotAgent } from "../../mrkl/index.js";
import { AgentExecutor } from "../../executor.js";
/**
 * Toolkit that bundles the SQL tools an agent uses against a single database.
 *
 * The constructor creates one QuerySqlTool, one InfoSqlTool and one
 * ListTablesSqlTool for the supplied database, plus a QueryCheckerTool that
 * receives the language model. The `dialect` property is initialised to
 * "sqlite" and is not reassigned anywhere in this class.
 *
 * @example
 * ```typescript
 * const model = new ChatOpenAI({});
 * const toolkit = new SqlToolkit(sqlDb, model);
 * const executor = createSqlAgent(model, toolkit);
 * const result = await executor.invoke({
 *   input: "List the total sales per country. Which country's customers spent the most?",
 * });
 * console.log(`Got output ${result.output}`);
 * ```
 */
export class SqlToolkit extends BaseToolkit {
    /**
     * @param db - Database handle handed to each database-backed tool.
     * @param llm - Language model handed to the query checker tool as `{ llm }`.
     */
    constructor(db, llm) {
        super();
        // Every instance field is declared as an own, enumerable, configurable
        // and writable data property before it receives its real value.
        const ownField = (value) => ({
            enumerable: true,
            configurable: true,
            writable: true,
            value,
        });
        Object.defineProperties(this, {
            tools: ownField(undefined),
            db: ownField(undefined),
            dialect: ownField("sqlite"),
        });
        this.db = db;
        // Tools are built in the same order in which they are exposed.
        const queryTool = new QuerySqlTool(db);
        const infoTool = new InfoSqlTool(db);
        const listTablesTool = new ListTablesSqlTool(db);
        const queryCheckerTool = new QueryCheckerTool({ llm });
        this.tools = [queryTool, infoTool, listTablesTool, queryCheckerTool];
    }
}
/**
 * Builds an agent executor that drives a ZeroShotAgent over a toolkit's tools.
 *
 * The prefix template is rendered as an f-string with the toolkit's `dialect`
 * and the `topK` value (exposed to the template as `top_k`). The rendered
 * prefix, the suffix and the input variables are then used to create the
 * agent prompt.
 *
 * @param llm - Language model used by the agent's LLMChain.
 * @param toolkit - Object providing `tools` and `dialect`.
 * @param args - Optional settings; each falls back to its default when undefined.
 * @param args.prefix - Prompt prefix template. Defaults to SQL_PREFIX.
 * @param args.suffix - Prompt suffix. Defaults to SQL_SUFFIX.
 * @param args.inputVariables - Prompt input variables. Defaults to
 *   ["input", "agent_scratchpad"].
 * @param args.topK - Value substituted for `top_k` in the prefix. Defaults to 10.
 * @returns The value of AgentExecutor.fromAgentAndTools, created with
 *   `returnIntermediateSteps` set to true.
 */
export function createSqlAgent(llm, toolkit, args) {
    const {
        prefix: prefixTemplate = SQL_PREFIX,
        suffix: promptSuffix = SQL_SUFFIX,
        inputVariables: promptInputs = ["input", "agent_scratchpad"],
        topK: rowLimit = 10,
    } = args ?? {};
    const toolset = toolkit.tools;

    // Fill the dialect and row limit into the prefix before the prompt is built.
    const templateValues = {
        dialect: toolkit.dialect,
        top_k: rowLimit,
    };
    const renderedPrefix = renderTemplate(prefixTemplate, "f-string", templateValues);

    const agentPrompt = ZeroShotAgent.createPrompt(toolset, {
        prefix: renderedPrefix,
        suffix: promptSuffix,
        inputVariables: promptInputs,
    });
    const llmChain = new LLMChain({ prompt: agentPrompt, llm });

    // The agent may only call the tools supplied by the toolkit.
    const allowedTools = toolset.map(({ name }) => name);
    const sqlAgent = new ZeroShotAgent({ llmChain, allowedTools });

    return AgentExecutor.fromAgentAndTools({
        agent: sqlAgent,
        tools: toolset,
        returnIntermediateSteps: true,
    });
}