163 lines
5.6 KiB
JavaScript
163 lines
5.6 KiB
JavaScript
import { LLMChain } from "../llm_chain.js";
|
|
import { BaseChain } from "../base.js";
|
|
import { CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT } from "./prompts.js";
|
|
import { logVersion020MigrationWarning } from "../../util/entrypoint_deprecation.js";
|
|
/* #__PURE__ */ logVersion020MigrationWarning({
|
|
oldEntrypointName: "chains/graph_qa/cypher",
|
|
newPackageName: "@langchain/community",
|
|
});
|
|
export const INTERMEDIATE_STEPS_KEY = "intermediateSteps";
|
|
/**
|
|
* @example
|
|
* ```typescript
|
|
* const chain = new GraphCypherQAChain({
|
|
* llm: new ChatOpenAI({ temperature: 0 }),
|
|
* graph: new Neo4jGraph(),
|
|
* });
|
|
* const res = await chain.invoke("Who played in Pulp Fiction?");
|
|
* ```
|
|
*/
|
|
export class GraphCypherQAChain extends BaseChain {
|
|
constructor(props) {
|
|
super(props);
|
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
Object.defineProperty(this, "graph", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: void 0
|
|
});
|
|
Object.defineProperty(this, "cypherGenerationChain", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: void 0
|
|
});
|
|
Object.defineProperty(this, "qaChain", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: void 0
|
|
});
|
|
Object.defineProperty(this, "inputKey", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: "query"
|
|
});
|
|
Object.defineProperty(this, "outputKey", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: "result"
|
|
});
|
|
Object.defineProperty(this, "topK", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: 10
|
|
});
|
|
Object.defineProperty(this, "returnDirect", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: false
|
|
});
|
|
Object.defineProperty(this, "returnIntermediateSteps", {
|
|
enumerable: true,
|
|
configurable: true,
|
|
writable: true,
|
|
value: false
|
|
});
|
|
const { graph, cypherGenerationChain, qaChain, inputKey, outputKey, topK, returnIntermediateSteps, returnDirect, } = props;
|
|
this.graph = graph;
|
|
this.cypherGenerationChain = cypherGenerationChain;
|
|
this.qaChain = qaChain;
|
|
if (inputKey) {
|
|
this.inputKey = inputKey;
|
|
}
|
|
if (outputKey) {
|
|
this.outputKey = outputKey;
|
|
}
|
|
if (topK) {
|
|
this.topK = topK;
|
|
}
|
|
if (returnIntermediateSteps) {
|
|
this.returnIntermediateSteps = returnIntermediateSteps;
|
|
}
|
|
if (returnDirect) {
|
|
this.returnDirect = returnDirect;
|
|
}
|
|
}
|
|
_chainType() {
|
|
return "graph_cypher_chain";
|
|
}
|
|
get inputKeys() {
|
|
return [this.inputKey];
|
|
}
|
|
get outputKeys() {
|
|
return [this.outputKey];
|
|
}
|
|
static fromLLM(props) {
|
|
const { graph, qaPrompt = CYPHER_QA_PROMPT, cypherPrompt = CYPHER_GENERATION_PROMPT, llm, cypherLLM, qaLLM, returnIntermediateSteps = false, returnDirect = false, } = props;
|
|
if (!cypherLLM && !llm) {
|
|
throw new Error("Either 'llm' or 'cypherLLM' parameters must be provided");
|
|
}
|
|
if (!qaLLM && !llm) {
|
|
throw new Error("Either 'llm' or 'qaLLM' parameters must be provided");
|
|
}
|
|
if (cypherLLM && qaLLM && llm) {
|
|
throw new Error("You can specify up to two of 'cypherLLM', 'qaLLM', and 'llm', but not all three simultaneously.");
|
|
}
|
|
const qaChain = new LLMChain({
|
|
llm: (qaLLM || llm),
|
|
prompt: qaPrompt,
|
|
});
|
|
const cypherGenerationChain = new LLMChain({
|
|
llm: (cypherLLM || llm),
|
|
prompt: cypherPrompt,
|
|
});
|
|
return new GraphCypherQAChain({
|
|
cypherGenerationChain,
|
|
qaChain,
|
|
graph,
|
|
returnIntermediateSteps,
|
|
returnDirect,
|
|
});
|
|
}
|
|
extractCypher(text) {
|
|
const pattern = /```(.*?)```/s;
|
|
const matches = text.match(pattern);
|
|
return matches ? matches[1] : text;
|
|
}
|
|
async _call(values, runManager) {
|
|
const callbacks = runManager?.getChild();
|
|
const question = values[this.inputKey];
|
|
const intermediateSteps = [];
|
|
const generatedCypher = await this.cypherGenerationChain.call({ question, schema: this.graph.getSchema() }, callbacks);
|
|
const extractedCypher = this.extractCypher(generatedCypher.text);
|
|
await runManager?.handleText(`Generated Cypher:\n`);
|
|
await runManager?.handleText(`${extractedCypher} green\n`);
|
|
intermediateSteps.push({ query: extractedCypher });
|
|
let chainResult;
|
|
const context = await this.graph.query(extractedCypher, {
|
|
topK: this.topK,
|
|
});
|
|
if (this.returnDirect) {
|
|
chainResult = { [this.outputKey]: context };
|
|
}
|
|
else {
|
|
await runManager?.handleText("Full Context:\n");
|
|
await runManager?.handleText(`${context} green\n`);
|
|
intermediateSteps.push({ context });
|
|
const result = await this.qaChain.call({ question, context: JSON.stringify(context) }, callbacks);
|
|
chainResult = {
|
|
[this.outputKey]: result[this.qaChain.outputKey],
|
|
};
|
|
}
|
|
if (this.returnIntermediateSteps) {
|
|
chainResult[INTERMEDIATE_STEPS_KEY] = intermediateSteps;
|
|
}
|
|
return chainResult;
|
|
}
|
|
}
|