354 lines
15 KiB
JavaScript
354 lines
15 KiB
JavaScript
|
/* eslint-disable prefer-template */
|
||
|
import { v4 as uuidv4 } from "uuid";
|
||
|
import { VectorStore, } from "@langchain/core/vectorstores";
|
||
|
import { Document } from "@langchain/core/documents";
|
||
|
import { maximalMarginalRelevance } from "@langchain/core/utils/math";
|
||
|
import { CassandraTable, } from "../utils/cassandra.js";
|
||
|
/**
|
||
|
* Class for interacting with the Cassandra database. It extends the
|
||
|
* VectorStore class and provides methods for adding vectors and
|
||
|
* documents, searching for similar vectors, and creating instances from
|
||
|
* texts or documents.
|
||
|
*/
|
||
|
export class CassandraStore extends VectorStore {
    /**
     * Identifier for this vector store implementation.
     * @returns The string "cassandra".
     */
    _vectorstoreType() {
        return "cassandra";
    }

    /**
     * Validates the constructor arguments and fills in defaults for the
     * primary key, non-key columns, metadata columns and indices.
     *
     * The arrays supplied by the caller are copied, never mutated, so the
     * same `args` object can safely be reused to build several stores.
     *
     * @param args Store configuration. `table` and `dimensions` are required;
     *   `vectorType` defaults to "cosine".
     * @returns A shallow copy of `args` with `vectorType`, `primaryKey`,
     *   `nonKeyColumns`, `metadataColumns` and `indices` replaced by their
     *   normalized values.
     * @throws Error if `table` or `dimensions` is missing.
     */
    _cleanArgs(args) {
        const { table, dimensions, primaryKey, nonKeyColumns, indices, metadataColumns, vectorType = "cosine", } = args;
        if (!table || !dimensions) {
            throw new Error("Missing required arguments");
        }
        // Utility function to ensure the argument is treated as an array
        const toArray = (value) => (Array.isArray(value) ? value : [value]);
        // Copy the indices: defaults are pushed below and must not leak into
        // the array owned by the caller.
        const indicesArg = [...toArray(indices || [])];
        // Use the primary key if provided, else default to a single auto-generated UUID column
        const primaryKeyArg = primaryKey
            ? toArray(primaryKey)
            : [{ name: this.idColumnAutoName, type: "uuid", partition: true }];
        // The combined nonKeyColumns and metadataColumns, de-duped by name
        // (the first occurrence of a name wins).
        const seenNames = new Set();
        const nonKeyColumnsArg = [
            ...toArray(nonKeyColumns || []),
            ...toArray(metadataColumns || []),
        ].filter((col) => {
            if (seenNames.has(col.name)) {
                return false;
            }
            seenNames.add(col.name);
            return true;
        });
        // If no metadata columns are specified, add a default metadata column consistent with Langchain Python
        if (nonKeyColumnsArg.length === 0) {
            nonKeyColumnsArg.push({
                name: this.metadataColumnDefaultName,
                type: "map<text, text>",
            });
            indicesArg.push({
                name: `idx_${this.metadataColumnDefaultName}_${table}_keys`,
                value: `(keys(${this.metadataColumnDefaultName}))`,
            });
            indicesArg.push({
                name: `idx_${this.metadataColumnDefaultName}_${table}_entries`,
                value: `(entries(${this.metadataColumnDefaultName}))`,
            });
        }
        // Appends a column only when no column with the same name exists yet.
        const addDefaultNonKeyColumnIfNeeded = (defaultColumn) => {
            if (!nonKeyColumnsArg.some((col) => col.name === defaultColumn.name)) {
                nonKeyColumnsArg.push(defaultColumn);
            }
        };
        addDefaultNonKeyColumnIfNeeded({ name: this.textColumnName, type: "text" });
        addDefaultNonKeyColumnIfNeeded({
            name: this.vectorColumnName,
            type: `VECTOR<FLOAT,${dimensions}>`,
            alias: this.embeddingColumnAlias,
        });
        // If no index is specified for the vector column, add a default index.
        // The pattern is built once; it has no `g` flag, so `test` is stateless.
        const vectorIndexPattern = new RegExp(`\\(\\s*${this.vectorColumnName.toLowerCase()}\\s*\\)`);
        if (!indicesArg.some((index) => vectorIndexPattern.test(index.value.toLowerCase()))) {
            indicesArg.push({
                name: `idx_${this.vectorColumnName}_${table}`,
                value: `(${this.vectorColumnName})`,
                options: `{'similarity_function': '${vectorType.toLowerCase()}'}`,
            });
        }
        // Metadata the user will see excludes vector column and text column
        const metadataColumnsArg = [...primaryKeyArg, ...nonKeyColumnsArg].filter((column) => column.name !== this.vectorColumnName &&
            column.name !== this.textColumnName);
        return {
            ...args,
            vectorType,
            primaryKey: primaryKeyArg,
            nonKeyColumns: nonKeyColumnsArg,
            metadataColumns: metadataColumnsArg,
            indices: indicesArg,
        };
    }

    /**
     * Looks up a column definition by name.
     * @param columns A single column or an array of columns.
     * @param columnName The name to search for.
     * @returns The first column whose `name` equals `columnName`.
     * @throws Error if no such column exists.
     */
    _getColumnByName(columns, columnName) {
        const columnsArray = Array.isArray(columns) ? columns : [columns];
        const column = columnsArray.find((col) => col.name === columnName);
        if (!column) {
            throw new Error(`Column ${columnName} not found`);
        }
        return column;
    }

    /**
     * Creates the store: defines the instance fields, normalizes `args`
     * through `_cleanArgs`, and builds the backing `CassandraTable`.
     * @param embeddings The embeddings implementation passed to the base class.
     * @param args Store configuration (see `_cleanArgs`).
     * @throws Error if required arguments are missing or the text/vector
     *   columns cannot be resolved.
     */
    constructor(embeddings, args) {
        super(embeddings, args);
        // Instance fields with their initial values, in definition order.
        // `undefined` entries are assigned further down once args are cleaned.
        const fieldDefaults = [
            ["table", undefined],
            ["idColumnAutoName", "id"],
            ["idColumnAutoGenerated", undefined],
            ["vectorColumnName", "vector"],
            ["vectorColumn", undefined],
            ["textColumnName", "text"],
            ["textColumn", undefined],
            ["metadataColumnDefaultName", "metadata"],
            ["metadataColumns", undefined],
            ["similarityColumn", undefined],
            ["embeddingColumnAlias", "embedding"],
        ];
        for (const [fieldName, value] of fieldDefaults) {
            Object.defineProperty(this, fieldName, {
                enumerable: true,
                configurable: true,
                writable: true,
                value,
            });
        }
        const cleanedArgs = this._cleanArgs(args);
        // This check here to help the compiler understand that nonKeyColumns will always
        // have values after the _cleanArgs call. It is the cleanest way to handle the fact
        // that the compiler is not able to make this determination, no matter how hard we try!
        if (!cleanedArgs.nonKeyColumns || cleanedArgs.nonKeyColumns.length === 0) {
            throw new Error("No non-key columns provided");
        }
        this.vectorColumn = this._getColumnByName(cleanedArgs.nonKeyColumns, this.vectorColumnName);
        this.textColumn = this._getColumnByName(cleanedArgs.nonKeyColumns, this.textColumnName);
        // Pseudo-column selected by `search`; the `?` is bound to the query vector there.
        this.similarityColumn = {
            name: `similarity_${cleanedArgs.vectorType}(${this.vectorColumnName},?)`,
            alias: "similarity_score",
            type: "",
        };
        // The id column is auto-generated only when the caller gave no primary key.
        this.idColumnAutoGenerated = !args.primaryKey;
        this.metadataColumns = cleanedArgs.metadataColumns;
        this.table = new CassandraTable(cleanedArgs);
    }

    /**
     * Method to save vectors to the Cassandra database.
     *
     * When the id column is auto-generated and a document has no id (missing
     * or empty string), a UUID is generated and written into that document's
     * `metadata` object, so the caller can read it back afterwards.
     *
     * @param vectors Vectors to save.
     * @param documents The documents associated with the vectors, index-aligned with `vectors`.
     * @returns Promise that resolves when the vectors have been added.
     */
    async addVectors(vectors, documents) {
        if (vectors.length === 0) {
            return;
        }
        // Prepare the values for upsert
        const values = vectors.map((vector, index) => {
            const document = documents[index];
            const docMetadata = document.metadata || {};
            // If idColumnAutoGenerated is true and ID is not provided, generate a UUID
            if (this.idColumnAutoGenerated &&
                (docMetadata[this.idColumnAutoName] === undefined ||
                    docMetadata[this.idColumnAutoName] === "")) {
                docMetadata[this.idColumnAutoName] = uuidv4();
            }
            // One value per metadata column. Only a missing value becomes null:
            // legitimate falsy values such as 0, false or "" are kept as-is.
            const row = this.metadataColumns.map((col) => docMetadata[col.name] ?? null);
            // Add the text content and vector
            row.push(document.pageContent);
            row.push(new Float32Array(vector));
            return row;
        });
        // Column order must match the order in which each row was built above.
        const columns = [
            ...this.metadataColumns,
            { name: this.textColumnName, type: "" },
            { name: this.vectorColumnName, type: "" },
        ];
        return this.table.upsert(values, columns);
    }

    /**
     * Accessor for the underlying table object.
     * @returns The `CassandraTable` instance created by the constructor.
     */
    getCassandraTable() {
        return this.table;
    }

    /**
     * Method to add documents to the Cassandra database.
     * @param documents The documents to add.
     * @returns Promise that resolves when the documents have been added.
     */
    async addDocuments(documents) {
        const texts = documents.map((d) => d.pageContent);
        const vectors = await this.embeddings.embedDocuments(texts);
        return this.addVectors(vectors, documents);
    }

    /**
     * Helper method to search for vectors that are similar to a given query vector.
     * @param query The query vector.
     * @param k The number of similar Documents to return.
     * @param filter Optional filter to be applied as a WHERE clause.
     * @param includeEmbedding Whether to include the embedding vectors in the results.
     * @returns Promise that resolves with an array of tuples, each containing a Document and a score.
     *   The array is empty when the query yields no result set.
     */
    async search(query, k, filter, includeEmbedding) {
        const vectorAsFloat32Array = new Float32Array(query);
        const similarityColumnWithBinds = {
            ...this.similarityColumn,
            binds: [vectorAsFloat32Array],
        };
        const queryCols = [
            ...this.metadataColumns,
            this.textColumn,
            similarityColumnWithBinds,
        ];
        if (includeEmbedding) {
            queryCols.push(this.vectorColumn);
        }
        const orderBy = {
            name: this.vectorColumnName,
            operator: "ANN OF",
            value: [vectorAsFloat32Array],
        };
        const queryResultSet = await this.table.select(queryCols, filter, [orderBy], k);
        // Always hand back an array so callers can iterate without a null check.
        const rows = queryResultSet?.rows ?? [];
        return rows.map((row) => {
            const textContent = row[this.textColumnName];
            // Metadata is the row minus the text, the score and any null columns.
            const sanitizedRow = { ...row };
            delete sanitizedRow[this.textColumnName];
            delete sanitizedRow.similarity_score;
            Object.keys(sanitizedRow).forEach((key) => {
                if (sanitizedRow[key] === null) {
                    delete sanitizedRow[key];
                }
            });
            return [
                new Document({ pageContent: textContent, metadata: sanitizedRow }),
                row.similarity_score,
            ];
        });
    }

    /**
     * Method to search for vectors that are similar to a given query vector.
     * @param query The query vector.
     * @param k The number of similar Documents to return.
     * @param filter Optional filter to be applied as a WHERE clause.
     * @returns Promise that resolves with an array of tuples, each containing a Document and a score.
     */
    async similaritySearchVectorWithScore(query, k, filter) {
        return this.search(query, k, filter, false);
    }

    /**
     * Method to search for vectors that are similar to a given query vector, but with
     * the results selected using the maximal marginal relevance.
     * @param query The query string.
     * @param options.k The number of similar Documents to return.
     * @param options.fetchK=4*k The number of records to fetch before passing to the MMR algorithm.
     * @param options.lambda=0.5 The degree of diversity among the results between 0 (maximum diversity) and 1 (minimum diversity).
     * @param options.filter Optional filter to be applied as a WHERE clause.
     * @returns List of documents selected by maximal marginal relevance.
     */
    async maxMarginalRelevanceSearch(query, options) {
        const { k, fetchK = 4 * k, lambda = 0.5, filter } = options;
        const queryEmbedding = await this.embeddings.embedQuery(query);
        const queryResults = await this.search(queryEmbedding, fetchK, filter, true);
        // Nothing to rank: skip MMR rather than indexing into an empty list.
        if (queryResults.length === 0) {
            return [];
        }
        // The embedding is read from each document's metadata under the alias key.
        const embeddingList = queryResults.map((doc) => doc[0].metadata[this.embeddingColumnAlias]);
        const mmrIndexes = maximalMarginalRelevance(queryEmbedding, embeddingList, lambda, k);
        return mmrIndexes.map((idx) => {
            const doc = queryResults[idx][0];
            // The embedding was only fetched for ranking; hide it from the caller.
            delete doc.metadata[this.embeddingColumnAlias];
            return doc;
        });
    }

    /**
     * Static method to create an instance of CassandraStore from texts.
     * @param texts The texts to use.
     * @param metadatas The metadata associated with the texts: either an array
     *   index-aligned with `texts`, or a single object applied to every text.
     * @param embeddings The embeddings to use.
     * @param args The arguments for the CassandraStore.
     * @returns Promise that resolves with a new instance of CassandraStore.
     */
    static async fromTexts(texts, metadatas, embeddings, args) {
        const docs = texts.map((text, index) => {
            // A single shared metadata object is copied per document. Otherwise the
            // id that addVectors writes into the metadata of the first document
            // would be reused by every other document and the rows would collide.
            const metadata = Array.isArray(metadatas)
                ? metadatas[index]
                : { ...metadatas };
            return new Document({ pageContent: text, metadata });
        });
        return CassandraStore.fromDocuments(docs, embeddings, args);
    }

    /**
     * Static method to create an instance of CassandraStore from documents.
     * @param docs The documents to use.
     * @param embeddings The embeddings to use.
     * @param args The arguments for the CassandraStore.
     * @returns Promise that resolves with a new instance of CassandraStore.
     */
    static async fromDocuments(docs, embeddings, args) {
        const instance = new this(embeddings, args);
        await instance.addDocuments(docs);
        return instance;
    }

    /**
     * Static method to create an instance of CassandraStore from an existing
     * index.
     * @param embeddings The embeddings to use.
     * @param args The arguments for the CassandraStore.
     * @returns Promise that resolves with a new instance of CassandraStore.
     */
    static async fromExistingIndex(embeddings, args) {
        return new this(embeddings, args);
    }
}
|