131 lines
4.6 KiB
JavaScript
131 lines
4.6 KiB
JavaScript
"use strict";
|
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
exports.getPromptTemplateFromKey = exports.HydeRetriever = void 0;
|
|
const prompts_1 = require("@langchain/core/prompts");
|
|
const prompt_values_1 = require("@langchain/core/prompt_values");
|
|
const vectorstores_1 = require("@langchain/core/vectorstores");
|
|
/**
 * Retriever that searches a vector store with a hypothetical answer instead
 * of the raw query.
 *
 * It extends VectorStoreRetriever: the query (optionally wrapped in a prompt
 * template) is sent to a language model, and the text the model generates is
 * what gets passed to the vector store's similarity search.
 *
 * @example
 * ```typescript
 * const vectorStore = new MemoryVectorStore(new OpenAIEmbeddings());
 * const retriever = new HydeRetriever({
 *   vectorStore,
 *   llm: new ChatOpenAI(),
 *   k: 1,
 * });
 * await vectorStore.addDocuments(
 *   [
 *     "My name is John.",
 *     "My name is Bob.",
 *     "My favourite food is pizza.",
 *     "My favourite food is pasta.",
 *   ].map((pageContent) => new Document({ pageContent })),
 * );
 * const results = await retriever.getRelevantDocuments(
 *   "What is my favourite food?",
 * );
 * ```
 */
class HydeRetriever extends vectorstores_1.VectorStoreRetriever {
    /** Name reported for this class (used by the `lc_*` naming hooks). */
    static lc_name() {
        return "HydeRetriever";
    }
    /** Namespace path reported for this class. */
    get lc_namespace() {
        return ["langchain", "retrievers", "hyde"];
    }
    /**
     * @param fields - Options forwarded to the VectorStoreRetriever
     *   constructor, plus:
     *   - `llm`: the model used to generate the hypothetical answer;
     *   - `promptTemplate` (optional): either a key accepted by
     *     `getPromptTemplateFromKey` or a prompt template object exposing
     *     `inputVariables` and `formatPromptValue`.
     * @throws {Error} If a prompt template is supplied whose input variables
     *   are anything other than exactly `["question"]`, or if a string key is
     *   not recognised by `getPromptTemplateFromKey`.
     */
    constructor(fields) {
        super(fields);
        Object.defineProperty(this, "llm", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "promptTemplate", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        this.llm = fields.llm;
        // A string is treated as the key of one of the built-in templates.
        this.promptTemplate =
            typeof fields.promptTemplate === "string"
                ? getPromptTemplateFromKey(fields.promptTemplate)
                : fields.promptTemplate;
        if (this.promptTemplate) {
            const { inputVariables } = this.promptTemplate;
            // Reject the template if EITHER condition fails. The template is
            // later formatted with only `{ question }`, so any extra or
            // differently named variable could never be filled in.
            if (inputVariables.length !== 1 || inputVariables[0] !== "question") {
                throw new Error(`Prompt template must accept a single input variable 'question'. Invalid input variables for prompt template: ${inputVariables}`);
            }
        }
    }
    /**
     * Generates a hypothetical answer for `query` and returns the documents
     * most similar to that answer.
     *
     * @param query - The user's query string.
     * @param runManager - Optional callback manager; when present, a
     *   "vectorstore" child is passed to the similarity search.
     * @returns The result of `vectorStore.similaritySearch` for the generated
     *   answer, using this retriever's `k` and `filter`.
     */
    async _getRelevantDocuments(query, runManager) {
        // Without a template the raw query itself is the prompt.
        const value = this.promptTemplate
            ? await this.promptTemplate.formatPromptValue({ question: query })
            : new prompt_values_1.StringPromptValue(query);
        // Get a hypothetical answer from the LLM (first generation of the
        // single prompt submitted).
        const res = await this.llm.generatePrompt([value]);
        const answer = res.generations[0][0].text;
        // Retrieve relevant documents based on the hypothetical answer.
        return this.vectorStore.similaritySearch(answer, this.k, this.filter, runManager?.getChild("vectorstore"));
    }
}
|
|
exports.HydeRetriever = HydeRetriever;
|
|
/**
 * Built-in prompt texts, indexed by prompt key. Every template has exactly
 * one placeholder, `{question}`, and consists of three newline-separated
 * lines: the instruction, the labelled input, and the label the model is
 * expected to continue from.
 */
const HYDE_PROMPT_TEMPLATES = new Map([
    ["websearch", "Please write a passage to answer the question\nQuestion: {question}\nPassage:"],
    ["scifact", "Please write a scientific paper passage to support/refute the claim\nClaim: {question}\nPassage:"],
    ["arguana", "Please write a counter argument for the passage\nPassage: {question}\nCounter Argument:"],
    ["trec-covid", "Please write a scientific paper passage to answer the question\nQuestion: {question}\nPassage:"],
    ["fiqa", "Please write a financial article passage to answer the question\nQuestion: {question}\nPassage:"],
    ["dbpedia-entity", "Please write a passage to answer the question.\nQuestion: {question}\nPassage:"],
    ["trec-news", "Please write a news passage about the topic.\nTopic: {question}\nPassage:"],
    ["mr-tydi", "Please write a passage in Swahili/Korean/Japanese/Bengali to answer the question in detail.\nQuestion: {question}\nPassage:"],
]);
/**
 * Returns a prompt template instance for the given prompt key.
 *
 * @param key - One of "websearch", "scifact", "arguana", "trec-covid",
 *   "fiqa", "dbpedia-entity", "trec-news" or "mr-tydi".
 * @returns The value of `PromptTemplate.fromTemplate` applied to the prompt
 *   text registered for `key`.
 * @throws {Error} If `key` is not one of the keys listed above.
 */
function getPromptTemplateFromKey(key) {
    // A Map (rather than a plain object) keeps inherited names such as
    // "toString" from being mistaken for valid keys.
    const template = HYDE_PROMPT_TEMPLATES.get(key);
    if (template === undefined) {
        throw new Error(`Invalid prompt key: ${key}`);
    }
    return prompts_1.PromptTemplate.fromTemplate(template);
}
|
|
exports.getPromptTemplateFromKey = getPromptTemplateFromKey;
|