agsamantha/node_modules/@langchain/community/dist/vectorstores/googlevertexai.js

537 lines
20 KiB
JavaScript
Raw Normal View History

2024-10-02 15:15:21 -05:00
import * as uuid from "uuid";
import flatten from "flat";
import { GoogleAuth } from "google-auth-library";
import { VectorStore } from "@langchain/core/vectorstores";
import { Document } from "@langchain/core/documents";
import { AsyncCaller, } from "@langchain/core/utils/async_caller";
import { GoogleVertexAIConnection } from "../utils/googlevertexai-connection.js";
/**
* A Document that optionally includes the ID of the document.
*/
export class IdDocument extends Document {
constructor(fields) {
super(fields);
Object.defineProperty(this, "id", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
this.id = fields.id;
}
}
class IndexEndpointConnection extends GoogleVertexAIConnection {
constructor(fields, caller) {
super(fields, caller, new GoogleAuth(fields.authOptions));
Object.defineProperty(this, "indexEndpoint", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
this.indexEndpoint = fields.indexEndpoint;
}
async buildUrl() {
const projectId = await this.client.getProjectId();
const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/indexEndpoints/${this.indexEndpoint}`;
return url;
}
buildMethod() {
return "GET";
}
async request(options) {
return this._request(undefined, options);
}
}
class RemoveDatapointConnection extends GoogleVertexAIConnection {
constructor(fields, caller) {
super(fields, caller, new GoogleAuth(fields.authOptions));
Object.defineProperty(this, "index", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
this.index = fields.index;
}
async buildUrl() {
const projectId = await this.client.getProjectId();
const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/indexes/${this.index}:removeDatapoints`;
return url;
}
buildMethod() {
return "POST";
}
async request(datapointIds, options) {
const data = {
datapointIds,
};
return this._request(data, options);
}
}
class UpsertDatapointConnection extends GoogleVertexAIConnection {
constructor(fields, caller) {
super(fields, caller, new GoogleAuth(fields.authOptions));
Object.defineProperty(this, "index", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
this.index = fields.index;
}
async buildUrl() {
const projectId = await this.client.getProjectId();
const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/indexes/${this.index}:upsertDatapoints`;
return url;
}
buildMethod() {
return "POST";
}
async request(datapoints, options) {
const data = {
datapoints,
};
return this._request(data, options);
}
}
class FindNeighborsConnection extends GoogleVertexAIConnection {
constructor(params, caller) {
super(params, caller, new GoogleAuth(params.authOptions));
Object.defineProperty(this, "indexEndpoint", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "deployedIndexId", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
this.indexEndpoint = params.indexEndpoint;
this.deployedIndexId = params.deployedIndexId;
}
async buildUrl() {
const projectId = await this.client.getProjectId();
const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/indexEndpoints/${this.indexEndpoint}:findNeighbors`;
return url;
}
buildMethod() {
return "POST";
}
async request(request, options) {
return this._request(request, options);
}
}
/**
* A class that represents a connection to a Google Vertex AI Matching Engine
* instance.
*/
export class MatchingEngine extends VectorStore {
constructor(embeddings, args) {
super(embeddings, args);
/**
* Docstore that retains the document, stored by ID
*/
Object.defineProperty(this, "docstore", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
/**
* The host to connect to for queries and upserts.
*/
Object.defineProperty(this, "apiEndpoint", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "apiVersion", {
enumerable: true,
configurable: true,
writable: true,
value: "v1"
});
Object.defineProperty(this, "endpoint", {
enumerable: true,
configurable: true,
writable: true,
value: "us-central1-aiplatform.googleapis.com"
});
Object.defineProperty(this, "location", {
enumerable: true,
configurable: true,
writable: true,
value: "us-central1"
});
/**
* The id for the index endpoint
*/
Object.defineProperty(this, "indexEndpoint", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
/**
* The id for the index
*/
Object.defineProperty(this, "index", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
/**
* Explicitly set Google Auth credentials if you cannot get them from google auth application-default login
* This is useful for serverless or autoscaling environments like Fargate
*/
Object.defineProperty(this, "authOptions", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
/**
* The id for the "deployed index", which is an identifier in the
* index endpoint that references the index (but is not the index id)
*/
Object.defineProperty(this, "deployedIndexId", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "callerParams", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "callerOptions", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "caller", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "indexEndpointClient", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "removeDatapointClient", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
Object.defineProperty(this, "upsertDatapointClient", {
enumerable: true,
configurable: true,
writable: true,
value: void 0
});
this.embeddings = embeddings;
this.docstore = args.docstore;
this.apiEndpoint = args.apiEndpoint ?? this.apiEndpoint;
this.deployedIndexId = args.deployedIndexId ?? this.deployedIndexId;
this.apiVersion = args.apiVersion ?? this.apiVersion;
this.endpoint = args.endpoint ?? this.endpoint;
this.location = args.location ?? this.location;
this.indexEndpoint = args.indexEndpoint ?? this.indexEndpoint;
this.index = args.index ?? this.index;
this.authOptions = args.authOptions ?? this.authOptions;
this.callerParams = args.callerParams ?? this.callerParams;
this.callerOptions = args.callerOptions ?? this.callerOptions;
this.caller = new AsyncCaller(this.callerParams || {});
const indexClientParams = {
endpoint: this.endpoint,
location: this.location,
apiVersion: this.apiVersion,
indexEndpoint: this.indexEndpoint,
authOptions: this.authOptions,
};
this.indexEndpointClient = new IndexEndpointConnection(indexClientParams, this.caller);
const removeClientParams = {
endpoint: this.endpoint,
location: this.location,
apiVersion: this.apiVersion,
index: this.index,
authOptions: this.authOptions,
};
this.removeDatapointClient = new RemoveDatapointConnection(removeClientParams, this.caller);
const upsertClientParams = {
endpoint: this.endpoint,
location: this.location,
apiVersion: this.apiVersion,
index: this.index,
authOptions: this.authOptions,
};
this.upsertDatapointClient = new UpsertDatapointConnection(upsertClientParams, this.caller);
}
_vectorstoreType() {
return "googlevertexai";
}
async addDocuments(documents) {
const texts = documents.map((doc) => doc.pageContent);
const vectors = await this.embeddings.embedDocuments(texts);
return this.addVectors(vectors, documents);
}
async addVectors(vectors, documents) {
if (vectors.length !== documents.length) {
throw new Error(`Vectors and metadata must have the same length`);
}
const datapoints = vectors.map((vector, idx) => this.buildDatapoint(vector, documents[idx]));
const options = {};
const response = await this.upsertDatapointClient.request(datapoints, options);
if (Object.keys(response?.data ?? {}).length === 0) {
// Nothing in the response in the body means we saved it ok
const idDoc = documents;
const docsToStore = {};
idDoc.forEach((doc) => {
if (doc.id) {
docsToStore[doc.id] = doc;
}
});
await this.docstore.add(docsToStore);
}
}
// TODO: Refactor this into a utility type and use with pinecone as well?
// eslint-disable-next-line @typescript-eslint/no-explicit-any
cleanMetadata(documentMetadata) {
function getStringArrays(prefix,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
m) {
let ret = {};
Object.keys(m).forEach((key) => {
const newPrefix = prefix.length > 0 ? `${prefix}.${key}` : key;
const val = m[key];
if (!val) {
// Ignore it
}
else if (Array.isArray(val)) {
// Make sure everything in the array is a string
ret[newPrefix] = val.map((v) => `${v}`);
}
else if (typeof val === "object") {
const subArrays = getStringArrays(newPrefix, val);
ret = { ...ret, ...subArrays };
}
});
return ret;
}
const stringArrays = getStringArrays("", documentMetadata);
const flatMetadata = flatten(documentMetadata);
Object.keys(flatMetadata).forEach((key) => {
Object.keys(stringArrays).forEach((arrayKey) => {
const matchKey = `${arrayKey}.`;
if (key.startsWith(matchKey)) {
delete flatMetadata[key];
}
});
});
const metadata = {
...flatMetadata,
...stringArrays,
};
return metadata;
}
/**
* Given the metadata from a document, convert it to an array of Restriction
* objects that may be passed to the Matching Engine and stored.
* The default implementation flattens any metadata and includes it as
* an "allowList". Subclasses can choose to convert some of these to
* "denyList" items or to add additional restrictions (for example, to format
* dates into a different structure or to add additional restrictions
* based on the date).
* @param documentMetadata - The metadata from a document
* @returns a Restriction[] (or an array of a subclass, from the FilterType)
*/
metadataToRestrictions(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
documentMetadata) {
const metadata = this.cleanMetadata(documentMetadata);
const restrictions = [];
for (const key of Object.keys(metadata)) {
// Make sure the value is an array (or that we'll ignore it)
let valArray;
const val = metadata[key];
if (val === null) {
valArray = null;
}
else if (Array.isArray(val) && val.length > 0) {
valArray = val;
}
else {
valArray = [`${val}`];
}
// Add to the restrictions if we do have a valid value
if (valArray) {
// Determine if this key is for the allowList or denyList
// TODO: get which ones should be on the deny list
const listType = "allowList";
// Create the restriction
const restriction = {
namespace: key,
[listType]: valArray,
};
// Add it to the restriction list
restrictions.push(restriction);
}
}
return restrictions;
}
/**
* Create an index datapoint for the vector and document id.
* If an id does not exist, create it and set the document to its value.
* @param vector
* @param document
*/
buildDatapoint(vector, document) {
if (!document.id) {
// eslint-disable-next-line no-param-reassign
document.id = uuid.v4();
}
const ret = {
datapointId: document.id,
featureVector: vector,
};
const restrictions = this.metadataToRestrictions(document.metadata);
if (restrictions?.length > 0) {
ret.restricts = restrictions;
}
return ret;
}
async delete(params) {
const options = {};
await this.removeDatapointClient.request(params.ids, options);
}
async similaritySearchVectorWithScore(query, k, filter) {
// Format the query into the request
const deployedIndexId = await this.getDeployedIndexId();
const requestQuery = {
neighborCount: k,
datapoint: {
datapointId: `0`,
featureVector: query,
},
};
if (filter) {
requestQuery.datapoint.restricts = filter;
}
const request = {
deployedIndexId,
queries: [requestQuery],
};
// Build the connection.
// Has to be done here, since we defer getting the endpoint until
// we need it.
const apiEndpoint = await this.getPublicAPIEndpoint();
const findNeighborsParams = {
endpoint: apiEndpoint,
indexEndpoint: this.indexEndpoint,
apiVersion: this.apiVersion,
location: this.location,
deployedIndexId,
authOptions: this.authOptions,
};
const connection = new FindNeighborsConnection(findNeighborsParams, this.caller);
// Make the call
const options = {};
const response = await connection.request(request, options);
// Get the document for each datapoint id and return them
const nearestNeighbors = response?.data?.nearestNeighbors ?? [];
const nearestNeighbor = nearestNeighbors[0];
const neighbors = nearestNeighbor?.neighbors ?? [];
const ret = await Promise.all(neighbors.map(async (neighbor) => {
const id = neighbor?.datapoint?.datapointId;
const distance = neighbor?.distance;
let doc;
try {
doc = await this.docstore.search(id);
}
catch (xx) {
// Documents that are in the index are returned, even if they
// are not in the document store, to allow for some way to get
// the id so they can be deleted.
console.error(xx);
console.warn([
`Document with id "${id}" is missing from the backing docstore.`,
`This can occur if you clear the docstore without deleting from the corresponding Matching Engine index.`,
`To resolve this, you should call .delete() with this id as part of the "ids" parameter.`,
].join("\n"));
doc = new Document({ pageContent: `Missing document ${id}` });
}
doc.id ??= id;
return [doc, distance];
}));
return ret;
}
/**
* For this index endpoint, figure out what API Endpoint URL and deployed
* index ID should be used to do upserts and queries.
* Also sets the `apiEndpoint` and `deployedIndexId` property for future use.
* @return The URL
*/
async determinePublicAPIEndpoint() {
const response = await this.indexEndpointClient.request(this.callerOptions);
// Get the endpoint
const publicEndpointDomainName = response?.data?.publicEndpointDomainName;
this.apiEndpoint = publicEndpointDomainName;
// Determine which of the deployed indexes match the index id
// and get the deployed index id. The list of deployed index ids
// contain the "index name" or path, but not the index id by itself,
// so we need to extract it from the name
const indexPathPattern = /projects\/.+\/locations\/.+\/indexes\/(.+)$/;
const deployedIndexes = response?.data?.deployedIndexes ?? [];
const deployedIndex = deployedIndexes.find((index) => {
const deployedIndexPath = index.index;
const match = deployedIndexPath.match(indexPathPattern);
if (match) {
const [, potentialIndexId] = match;
if (potentialIndexId === this.index) {
return true;
}
}
return false;
});
if (deployedIndex) {
this.deployedIndexId = deployedIndex.id;
}
return {
apiEndpoint: this.apiEndpoint,
deployedIndexId: this.deployedIndexId,
};
}
async getPublicAPIEndpoint() {
return (this.apiEndpoint ?? (await this.determinePublicAPIEndpoint()).apiEndpoint);
}
async getDeployedIndexId() {
return (this.deployedIndexId ??
(await this.determinePublicAPIEndpoint()).deployedIndexId);
}
static async fromTexts(texts, metadatas, embeddings, dbConfig) {
const docs = texts.map((text, index) => ({
pageContent: text,
metadata: Array.isArray(metadatas) ? metadatas[index] : metadatas,
}));
return this.fromDocuments(docs, embeddings, dbConfig);
}
static async fromDocuments(docs, embeddings, dbConfig) {
const ret = new MatchingEngine(embeddings, dbConfig);
await ret.addDocuments(docs);
return ret;
}
}