agsamantha/node_modules/langchain/dist/chains/graph_qa/cypher.js
2024-10-02 15:15:21 -05:00

163 lines
5.6 KiB
JavaScript

import { LLMChain } from "../llm_chain.js";
import { BaseChain } from "../base.js";
import { CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT } from "./prompts.js";
import { logVersion020MigrationWarning } from "../../util/entrypoint_deprecation.js";
// Module-load side effect: report that the "chains/graph_qa/cypher" entrypoint
// has a replacement in "@langchain/community". How the warning is surfaced
// (and whether it fires once or on every load) is decided by
// util/entrypoint_deprecation.js, which is not visible here.
// NOTE(review): `#__PURE__` marks this call as side-effect free for
// bundlers/minifiers, which may drop it since its result is unused — so the
// warning may not fire in bundled builds. Confirm this is intended.
/* #__PURE__ */ logVersion020MigrationWarning({
oldEntrypointName: "chains/graph_qa/cypher",
newPackageName: "@langchain/community",
});
// Output key under which GraphCypherQAChain attaches its intermediate steps
// (the generated query and, when the QA step runs, the retrieved context)
// if `returnIntermediateSteps` is enabled.
export const INTERMEDIATE_STEPS_KEY = "intermediateSteps";
/**
 * Question-answering chain over a graph store queried with Cypher.
 *
 * Steps performed by `_call`:
 *   1. `cypherGenerationChain` is called with the question and
 *      `graph.getSchema()` and its `text` output is parsed into a Cypher
 *      statement (see `extractCypher`).
 *   2. The statement is run with `graph.query(...)`; at most `topK` rows are kept.
 *   3. Unless `returnDirect` is set, `qaChain` turns the rows into the answer.
 *
 * @example
 * ```typescript
 * const chain = GraphCypherQAChain.fromLLM({
 *   llm: new ChatOpenAI({ temperature: 0 }),
 *   graph: new Neo4jGraph(),
 * });
 * const res = await chain.invoke({ query: "Who played in Pulp Fiction?" });
 * ```
 */
export class GraphCypherQAChain extends BaseChain {
    /**
     * @param {object} props
     * @param {*} props.graph - Graph store; `_call` uses its `getSchema()` and
     *   `query(cypher, options)` methods.
     * @param {*} props.cypherGenerationChain - Chain called with
     *   `{ question, schema }`; its `text` output holds the Cypher statement.
     * @param {*} props.qaChain - Chain called with `{ question, context }`;
     *   the answer is read from `result[qaChain.outputKey]`.
     * @param {string} [props.inputKey="query"] - Input key holding the question.
     * @param {string} [props.outputKey="result"] - Output key for the answer.
     * @param {number} [props.topK=10] - Maximum number of query rows kept.
     * @param {boolean} [props.returnIntermediateSteps=false] - Also return the
     *   generated query / context under `INTERMEDIATE_STEPS_KEY`.
     * @param {boolean} [props.returnDirect=false] - Return the query rows as the
     *   output instead of running `qaChain`.
     */
    constructor(props) {
        super(props);
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        Object.defineProperty(this, "graph", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "cypherGenerationChain", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "qaChain", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "inputKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "query"
        });
        Object.defineProperty(this, "outputKey", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "result"
        });
        Object.defineProperty(this, "topK", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: 10
        });
        Object.defineProperty(this, "returnDirect", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: false
        });
        Object.defineProperty(this, "returnIntermediateSteps", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: false
        });
        const { graph, cypherGenerationChain, qaChain, inputKey, outputKey, topK, returnIntermediateSteps, returnDirect, } = props;
        this.graph = graph;
        this.cypherGenerationChain = cypherGenerationChain;
        this.qaChain = qaChain;
        // Optional settings only override the defaults above when truthy.
        if (inputKey) {
            this.inputKey = inputKey;
        }
        if (outputKey) {
            this.outputKey = outputKey;
        }
        if (topK) {
            this.topK = topK;
        }
        if (returnIntermediateSteps) {
            this.returnIntermediateSteps = returnIntermediateSteps;
        }
        if (returnDirect) {
            this.returnDirect = returnDirect;
        }
    }
    /** @returns {string} Identifier of this chain type. */
    _chainType() {
        return "graph_cypher_chain";
    }
    /** @returns {string[]} The single input key (`this.inputKey`). */
    get inputKeys() {
        return [this.inputKey];
    }
    /** @returns {string[]} The single output key (`this.outputKey`). */
    get outputKeys() {
        return [this.outputKey];
    }
    /**
     * Builds the chain from language models instead of prebuilt sub-chains.
     *
     * `llm` is the fallback for whichever of `cypherLLM` / `qaLLM` is missing;
     * supplying all three is rejected as ambiguous.
     *
     * @param {object} props
     * @param {*} props.graph - Graph store passed to the constructor.
     * @param {*} [props.llm] - Model used for any step without a dedicated one.
     * @param {*} [props.cypherLLM] - Model for Cypher generation.
     * @param {*} [props.qaLLM] - Model for answering from query results.
     * @param {*} [props.cypherPrompt=CYPHER_GENERATION_PROMPT]
     * @param {*} [props.qaPrompt=CYPHER_QA_PROMPT]
     * @param {boolean} [props.returnIntermediateSteps=false]
     * @param {boolean} [props.returnDirect=false]
     * @param {string} [props.inputKey] - Forwarded to the constructor if given.
     * @param {string} [props.outputKey] - Forwarded to the constructor if given.
     * @param {number} [props.topK] - Forwarded to the constructor if given.
     * @returns {GraphCypherQAChain}
     * @throws {Error} When a step has no model, or when all three models are given.
     */
    static fromLLM(props) {
        const { graph, qaPrompt = CYPHER_QA_PROMPT, cypherPrompt = CYPHER_GENERATION_PROMPT, llm, cypherLLM, qaLLM, returnIntermediateSteps = false, returnDirect = false, inputKey, outputKey, topK, } = props;
        if (!cypherLLM && !llm) {
            throw new Error("Either 'llm' or 'cypherLLM' parameters must be provided");
        }
        if (!qaLLM && !llm) {
            throw new Error("Either 'llm' or 'qaLLM' parameters must be provided");
        }
        if (cypherLLM && qaLLM && llm) {
            throw new Error("You can specify up to two of 'cypherLLM', 'qaLLM', and 'llm', but not all three simultaneously.");
        }
        const qaChain = new LLMChain({
            llm: (qaLLM || llm),
            prompt: qaPrompt,
        });
        const cypherGenerationChain = new LLMChain({
            llm: (cypherLLM || llm),
            prompt: cypherPrompt,
        });
        // Undefined inputKey/outputKey/topK leave the constructor defaults in place.
        return new GraphCypherQAChain({
            cypherGenerationChain,
            qaChain,
            graph,
            inputKey,
            outputKey,
            topK,
            returnIntermediateSteps,
            returnDirect,
        });
    }
    /**
     * Extracts the Cypher statement from an LLM completion.
     *
     * If the text contains a ```-fenced block, the body of the first block is
     * used, with an optional leading `cypher` info-string line removed;
     * otherwise the whole text is used. Surrounding whitespace is trimmed.
     *
     * @param {string} text - Raw completion text.
     * @returns {string} The statement to run against the graph.
     */
    extractCypher(text) {
        // The tag is only stripped when followed by a newline, so a statement
        // that itself starts with `CYPHER ...` on the fence line is kept intact.
        const pattern = /```(?:cypher\s*\n)?(.*?)```/is;
        const match = text.match(pattern);
        return (match ? match[1] : text).trim();
    }
    /**
     * Generates Cypher for the question, runs it, and answers from the rows.
     *
     * @param {object} values - Chain inputs; the question is `values[this.inputKey]`.
     * @param {*} [runManager] - Optional callback manager used for progress
     *   text (`handleText`) and child callbacks (`getChild`).
     * @returns {Promise<object>} `{ [this.outputKey]: answer }` (the raw rows when
     *   `returnDirect` is set), plus `intermediateSteps` when
     *   `returnIntermediateSteps` is set.
     */
    async _call(values, runManager) {
        const callbacks = runManager?.getChild();
        const question = values[this.inputKey];
        const intermediateSteps = [];
        const generatedCypher = await this.cypherGenerationChain.call({ question, schema: this.graph.getSchema() }, callbacks);
        const extractedCypher = this.extractCypher(generatedCypher.text);
        await runManager?.handleText(`Generated Cypher:\n`);
        await runManager?.handleText(`${extractedCypher}\n`);
        intermediateSteps.push({ query: extractedCypher });
        const queryResult = await this.graph.query(extractedCypher, {
            topK: this.topK,
        });
        // Nothing here guarantees the graph applies the `topK` option.
        // NOTE(review): Neo4j-style graphs treat it as a Cypher parameter.
        // Cap array results locally so `topK` always bounds the rows used.
        const context = Array.isArray(queryResult)
            ? queryResult.slice(0, this.topK)
            : queryResult;
        let chainResult;
        if (this.returnDirect) {
            chainResult = { [this.outputKey]: context };
        }
        else {
            // Serialize once: the same text is logged and handed to the QA prompt.
            const serializedContext = JSON.stringify(context);
            await runManager?.handleText("Full Context:\n");
            await runManager?.handleText(`${serializedContext}\n`);
            intermediateSteps.push({ context });
            const result = await this.qaChain.call({ question, context: serializedContext }, callbacks);
            chainResult = {
                [this.outputKey]: result[this.qaChain.outputKey],
            };
        }
        if (this.returnIntermediateSteps) {
            chainResult[INTERMEDIATE_STEPS_KEY] = intermediateSteps;
        }
        return chainResult;
    }
}