import { LLMChain } from "../llm_chain.js";
import { BaseChain } from "../base.js";
import { CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT } from "./prompts.js";
import { logVersion020MigrationWarning } from "../../util/entrypoint_deprecation.js";

/* #__PURE__ */ logVersion020MigrationWarning({
  oldEntrypointName: "chains/graph_qa/cypher",
  newPackageName: "@langchain/community",
});

/**
 * Output key under which the intermediate steps (`[{ query }, { context }]`)
 * are added when `returnIntermediateSteps` is enabled.
 */
export const INTERMEDIATE_STEPS_KEY = "intermediateSteps";

/**
 * Matches the first ```-fenced span in LLM output. An optional `cypher`
 * language tag on the opening fence (```cypher + newline) is consumed outside
 * the capture group so it does not leak into the query text. The tag is only
 * stripped when it sits alone on the fence line, so an unlabelled one-line
 * fence such as ```MATCH (n) RETURN n``` is captured intact.
 */
const CYPHER_FENCE_PATTERN = /```(?:cypher[^\S\n]*\n)?(.*?)```/is;

/**
 * Question-answering chain over a graph database.
 *
 * 1. `cypherGenerationChain` turns the question (plus the graph schema) into
 *    a Cypher statement.
 * 2. The statement is run against `graph`.
 * 3. Unless `returnDirect` is set, `qaChain` turns the query results into a
 *    natural-language answer.
 *
 * @example
 * ```typescript
 * const chain = GraphCypherQAChain.fromLLM({
 *   llm: new ChatOpenAI({ temperature: 0 }),
 *   graph: new Neo4jGraph(),
 * });
 * const res = await chain.invoke({ query: "Who played in Pulp Fiction?" });
 * ```
 */
export class GraphCypherQAChain extends BaseChain {
  /**
   * @param {object} props
   * @param {*} props.graph - Graph store; this chain calls its `getSchema()`
   *   and `query(cypher, params)` methods.
   * @param {LLMChain} props.cypherGenerationChain - Called with
   *   `{ question, schema }`; its `text` output is parsed for Cypher.
   * @param {LLMChain} props.qaChain - Called with `{ question, context }`,
   *   where `context` is the JSON-serialized query results.
   * @param {string} [props.inputKey="query"] - Input key holding the question.
   * @param {string} [props.outputKey="result"] - Output key for the answer.
   * @param {number} [props.topK=10] - Maximum number of query result rows
   *   used as context (also forwarded to `graph.query` as the `topK` parameter).
   * @param {boolean} [props.returnIntermediateSteps=false] - Also return the
   *   generated query and context under `INTERMEDIATE_STEPS_KEY`.
   * @param {boolean} [props.returnDirect=false] - Return the raw query results
   *   instead of running `qaChain`.
   */
  constructor(props) {
    super(props);
    // Declare every public field as an own, enumerable, writable property in a
    // fixed order (the shape TypeScript emits for class fields), then apply
    // the caller's settings on top.
    const defineField = (name, value) =>
      Object.defineProperty(this, name, {
        enumerable: true,
        configurable: true,
        writable: true,
        value,
      });
    defineField("graph", undefined);
    defineField("cypherGenerationChain", undefined);
    defineField("qaChain", undefined);
    defineField("inputKey", "query");
    defineField("outputKey", "result");
    defineField("topK", 10);
    defineField("returnDirect", false);
    defineField("returnIntermediateSteps", false);

    const {
      graph,
      cypherGenerationChain,
      qaChain,
      inputKey,
      outputKey,
      topK,
      returnIntermediateSteps,
      returnDirect,
    } = props;
    this.graph = graph;
    this.cypherGenerationChain = cypherGenerationChain;
    this.qaChain = qaChain;
    // Optional settings keep their defaults unless a truthy value is supplied.
    if (inputKey) {
      this.inputKey = inputKey;
    }
    if (outputKey) {
      this.outputKey = outputKey;
    }
    if (topK) {
      this.topK = topK;
    }
    if (returnIntermediateSteps) {
      this.returnIntermediateSteps = returnIntermediateSteps;
    }
    if (returnDirect) {
      this.returnDirect = returnDirect;
    }
  }

  _chainType() {
    return "graph_cypher_chain";
  }

  get inputKeys() {
    return [this.inputKey];
  }

  get outputKeys() {
    return [this.outputKey];
  }

  /**
   * Builds the chain from language models instead of pre-built chains.
   *
   * `llm` is the fallback for whichever of `cypherLLM` (Cypher generation)
   * and `qaLLM` (answer generation) is omitted; supplying all three is
   * rejected as ambiguous.
   *
   * @param {object} props
   * @param {*} props.graph - Graph store passed through to the chain.
   * @param {*} [props.llm] - Model used for any step without a dedicated model.
   * @param {*} [props.cypherLLM] - Model for Cypher generation.
   * @param {*} [props.qaLLM] - Model for answering from the query results.
   * @param {*} [props.qaPrompt=CYPHER_QA_PROMPT]
   * @param {*} [props.cypherPrompt=CYPHER_GENERATION_PROMPT]
   * @param {boolean} [props.returnIntermediateSteps=false]
   * @param {boolean} [props.returnDirect=false]
   * @param {string} [props.inputKey] - Forwarded to the constructor (default "query").
   * @param {string} [props.outputKey] - Forwarded to the constructor (default "result").
   * @param {number} [props.topK] - Forwarded to the constructor (default 10).
   * @returns {GraphCypherQAChain}
   * @throws {Error} If neither `llm` nor `cypherLLM` is given, if neither
   *   `llm` nor `qaLLM` is given, or if all three are given.
   */
  static fromLLM(props) {
    const {
      graph,
      qaPrompt = CYPHER_QA_PROMPT,
      cypherPrompt = CYPHER_GENERATION_PROMPT,
      llm,
      cypherLLM,
      qaLLM,
      returnIntermediateSteps = false,
      returnDirect = false,
      inputKey,
      outputKey,
      topK,
    } = props;
    if (!cypherLLM && !llm) {
      throw new Error("Either 'llm' or 'cypherLLM' parameters must be provided");
    }
    if (!qaLLM && !llm) {
      throw new Error("Either 'llm' or 'qaLLM' parameters must be provided");
    }
    if (cypherLLM && qaLLM && llm) {
      throw new Error("You can specify up to two of 'cypherLLM', 'qaLLM', and 'llm', but not all three simultaneously.");
    }
    const qaChain = new LLMChain({
      llm: qaLLM || llm,
      prompt: qaPrompt,
    });
    const cypherGenerationChain = new LLMChain({
      llm: cypherLLM || llm,
      prompt: cypherPrompt,
    });
    // Undefined optional settings are ignored by the constructor, so
    // forwarding them unconditionally keeps the defaults intact.
    return new GraphCypherQAChain({
      cypherGenerationChain,
      qaChain,
      graph,
      returnIntermediateSteps,
      returnDirect,
      inputKey,
      outputKey,
      topK,
    });
  }

  /**
   * Extracts the Cypher statement from raw LLM output.
   *
   * Returns the contents of the first ```-fenced span, without an optional
   * `cypher` language tag on the opening fence. If there is no fence, the
   * whole text is returned. Surrounding whitespace is trimmed in both cases.
   *
   * @param {string} text - Raw model output.
   * @returns {string} The Cypher statement to execute.
   */
  extractCypher(text) {
    const matches = text.match(CYPHER_FENCE_PATTERN);
    return (matches ? matches[1] : text).trim();
  }

  /**
   * Runs the question through Cypher generation, query execution and
   * (unless `returnDirect`) answer generation.
   *
   * @param {object} values - Chain inputs; the question is read from `this.inputKey`.
   * @param {*} [runManager] - Callback manager for this run.
   * @returns {Promise<object>} `{ [outputKey]: answer }`, plus
   *   `INTERMEDIATE_STEPS_KEY` when `returnIntermediateSteps` is set.
   */
  async _call(values, runManager) {
    const callbacks = runManager?.getChild();
    const question = values[this.inputKey];
    const intermediateSteps = [];

    const generatedCypher = await this.cypherGenerationChain.call(
      { question, schema: this.graph.getSchema() },
      callbacks
    );
    const extractedCypher = this.extractCypher(generatedCypher.text);

    await runManager?.handleText(`Generated Cypher:\n`);
    await runManager?.handleText(`${extractedCypher}\n`);
    intermediateSteps.push({ query: extractedCypher });

    // `topK` is still forwarded as a query parameter so generated Cypher may
    // reference `$topK`. NOTE(review): that argument is presumably passed
    // through as Cypher parameters (as with Neo4jGraph) and does not by
    // itself bound the row count, so array results are also capped here.
    const queryResult = await this.graph.query(extractedCypher, {
      topK: this.topK,
    });
    const context = Array.isArray(queryResult)
      ? queryResult.slice(0, this.topK)
      : queryResult;

    let answer;
    if (this.returnDirect) {
      answer = context;
    } else {
      // Serialize once: the same JSON is logged and handed to the QA prompt.
      const serializedContext = JSON.stringify(context);
      await runManager?.handleText("Full Context:\n");
      await runManager?.handleText(`${serializedContext}\n`);
      intermediateSteps.push({ context });
      const result = await this.qaChain.call(
        { question, context: serializedContext },
        callbacks
      );
      answer = result[this.qaChain.outputKey];
    }

    const chainResult = { [this.outputKey]: answer };
    if (this.returnIntermediateSteps) {
      chainResult[INTERMEDIATE_STEPS_KEY] = intermediateSteps;
    }
    return chainResult;
  }
}