import { v4 as uuid4, validate } from "uuid";
import { Client } from "../index.js";
import { shuffle } from "../utils/shuffle.js";
import { AsyncCaller } from "../utils/async_caller.js";
import pRetry from "p-retry";
import { getCurrentRunTree, traceable } from "../traceable.js";

// Number of example ids sent per `listExamples` call.
// NOTE(review): presumably chosen to stay under a request-size limit — confirm.
const EXAMPLE_ID_CHUNK_SIZE = 99;

/**
 * Tells whether the list contains at least one entry that is not a plain
 * string, i.e. at least one experiment was passed as a results object.
 *
 * @param {Array<unknown>} value Resolved experiment entries.
 * @returns {boolean} `true` if any entry is not a string.
 */
function isExperimentResultsList(value) {
  return value.some((x) => typeof x !== "string");
}

/**
 * Reads the project backing an experiment.
 *
 * @param {Client} client Client used to read the project.
 * @param {string | { experimentName: string }} experiment Either an identifier
 *   string or an object exposing `experimentName`.
 * @returns {Promise<object>} Whatever `client.readProject` resolves to. The
 *   identifier is passed as `projectId` when it is a valid UUID and as
 *   `projectName` otherwise.
 */
async function loadExperiment(client, experiment) {
  const value =
    typeof experiment === "string" ? experiment : experiment.experimentName;
  return client.readProject(
    validate(value) ? { projectId: value } : { projectName: value }
  );
}

/**
 * Lists the runs of an experiment and returns the ones without a parent, with
 * any listed children attached as `child_runs`.
 *
 * @param {Client} client Client used to list runs.
 * @param {string} experiment Project id (when a valid UUID) or project name.
 * @param {{ loadNested: boolean }} options When `loadNested` is falsy the
 *   listing is restricted with `executionOrder: 1`.
 * @returns {Promise<object[]>} Runs whose `parent_run_id` is null/undefined.
 */
async function loadTraces(client, experiment, options) {
  const executionOrder = options.loadNested ? undefined : 1;
  const runs = await client.listRuns(
    validate(experiment)
      ? { projectId: experiment, executionOrder }
      : { projectName: experiment, executionOrder }
  );

  const treeMap = {};
  const runIdMap = {};
  const results = [];

  for await (const run of runs) {
    if (run.parent_run_id != null) {
      treeMap[run.parent_run_id] ??= [];
      treeMap[run.parent_run_id].push(run);
    } else {
      results.push(run);
    }
    runIdMap[run.id] = run;
  }

  for (const [parentRunId, childRuns] of Object.entries(treeMap)) {
    // NOTE(review): assumes every referenced parent run is part of the listing;
    // a missing parent would throw here — confirm against the API.
    const parentRun = runIdMap[parentRunId];
    parentRun.child_runs = childRuns.sort((a, b) => {
      // Runs without a dotted order keep their relative position.
      if (a.dotted_order == null || b.dotted_order == null) return 0;
      return a.dotted_order.localeCompare(b.dotted_order);
    });
  }

  return results;
}

/**
 * Runs comparative evaluators over the runs that several experiments produced
 * for the same dataset examples, and records the resulting scores as feedback
 * attached to a newly created comparative experiment.
 *
 * @param {Array<string | object | Promise<object>>} experiments At least two
 *   experiments. Each entry is an identifier string, or a (possibly promised)
 *   object exposing `experimentName` and `results`.
 * @param {object} options
 * @param {Function[]} options.evaluators Non-empty list of evaluators. Each is
 *   called with `(runs, example)` and must resolve to an object with `key` and
 *   a `scores` map keyed by run id.
 * @param {Client} [options.client] Client to use; a new `Client` is created
 *   when omitted.
 * @param {number} [options.maxConcurrency] Passed to `AsyncCaller`; negative
 *   values are rejected.
 * @param {string} [options.experimentPrefix] Prefix for the generated
 *   experiment name; defaults to the project names joined by " vs. ".
 * @param {string} [options.description] Forwarded to
 *   `createComparativeExperiment`.
 * @param {object} [options.metadata] Forwarded to
 *   `createComparativeExperiment`.
 * @param {boolean} [options.loadNested] Load child runs as well as root runs.
 * @param {boolean} [options.randomizeOrder] Shuffle the runs before handing
 *   them to each evaluator.
 * @returns {Promise<{ experimentName: string, results: object[] }>} The
 *   generated experiment name and one result per (example, evaluator) pair.
 * @throws {Error} If fewer than two experiments or no evaluators are given, if
 *   `maxConcurrency` is negative, if the experiments do not share a reference
 *   dataset, if they have no examples in common, if an example cannot be
 *   loaded, or if an evaluator scores a run id it was not given.
 */
export async function evaluateComparative(experiments, options) {
  if (experiments.length < 2) {
    throw new Error("Comparative evaluation requires at least 2 experiments.");
  }
  if (!options.evaluators.length) {
    throw new Error(
      "At least one evaluator is required for comparative evaluation."
    );
  }
  if (options.maxConcurrency && options.maxConcurrency < 0) {
    throw new Error("maxConcurrency must be a positive number.");
  }

  const client = options.client ?? new Client();
  const resolvedExperiments = await Promise.all(experiments);

  const projects = await (() => {
    if (!isExperimentResultsList(resolvedExperiments)) {
      return Promise.all(
        resolvedExperiments.map((experiment) =>
          loadExperiment(client, experiment)
        )
      );
    }

    // If we know the number of runs beforehand, check that the number of runs
    // in the project matches the expected number of runs, retrying until it
    // does.
    return Promise.all(
      resolvedExperiments.map((experiment) => {
        // A plain identifier carries no `results`, so there is no expected
        // count to compare against; reading `.results.length` on it would
        // throw a TypeError.
        if (typeof experiment === "string") {
          return loadExperiment(client, experiment);
        }
        return pRetry(
          async () => {
            const project = await loadExperiment(client, experiment);
            if (project.run_count !== experiment?.results.length) {
              throw new Error("Experiment is missing runs. Retrying.");
            }
            return project;
          },
          { factor: 2, minTimeout: 1000, retries: 10 }
        );
      })
    );
  })();

  if (new Set(projects.map((p) => p.reference_dataset_id)).size > 1) {
    throw new Error("All experiments must have the same reference dataset.");
  }
  const referenceDatasetId = projects.at(0)?.reference_dataset_id;
  if (!referenceDatasetId) {
    throw new Error(
      "Reference dataset is required for comparative evaluation."
    );
  }
  if (
    new Set(projects.map((p) => p.extra?.metadata?.dataset_version)).size > 1
  ) {
    console.warn(
      "Detected multiple dataset versions used by experiments, which may lead to inaccurate results."
    );
  }
  const datasetVersion = projects.at(0)?.extra?.metadata?.dataset_version;

  const id = uuid4();
  const experimentName = (() => {
    if (!options.experimentPrefix) {
      const names = projects
        .map((p) => p.name)
        .filter(Boolean)
        .join(" vs. ");
      return `${names}-${uuid4().slice(0, 4)}`;
    }
    return `${options.experimentPrefix}-${uuid4().slice(0, 4)}`;
  })();

  // TODO: add URL to the comparative experiment
  console.log(`Starting pairwise evaluation of: ${experimentName}`);

  const comparativeExperiment = await client.createComparativeExperiment({
    id,
    name: experimentName,
    experimentIds: projects.map((p) => p.id),
    description: options.description,
    metadata: options.metadata,
    referenceDatasetId,
  });

  const viewUrl = await (async () => {
    const projectId = projects.at(0)?.id ?? projects.at(1)?.id;
    const datasetId = comparativeExperiment?.reference_dataset_id;
    if (projectId && datasetId) {
      // Keep only the part of the project URL that precedes "/projects/p/".
      const hostUrl = (await client.getProjectUrl({ projectId }))
        .split("/projects/p/")
        .at(0);
      const result = new URL(`${hostUrl}/datasets/${datasetId}/compare`);
      result.searchParams.set(
        "selectedSessions",
        projects.map((p) => p.id).join(",")
      );
      result.searchParams.set(
        "comparativeExperiment",
        comparativeExperiment.id
      );
      return result.toString();
    }
    return null;
  })();
  if (viewUrl != null) {
    console.log(`View results at: ${viewUrl}`);
  }

  const experimentRuns = await Promise.all(
    projects.map((p) =>
      loadTraces(client, p.id, { loadNested: !!options.loadNested })
    )
  );

  // Only examples that every experiment has a run for can be compared.
  let exampleIdsIntersect;
  for (const runs of experimentRuns) {
    const exampleIdsSet = new Set(
      runs.map((r) => r.reference_example_id).filter((x) => x != null)
    );
    if (!exampleIdsIntersect) {
      exampleIdsIntersect = exampleIdsSet;
    } else {
      exampleIdsIntersect = new Set(
        [...exampleIdsIntersect].filter((x) => exampleIdsSet.has(x))
      );
    }
  }
  const commonExampleIds = exampleIdsIntersect ?? new Set();
  const exampleIds = [...commonExampleIds];
  if (!exampleIds.length) {
    throw new Error("No examples found in common between experiments.");
  }

  const exampleMap = {};
  for (
    let start = 0;
    start < exampleIds.length;
    start += EXAMPLE_ID_CHUNK_SIZE
  ) {
    const exampleIdsChunk = exampleIds.slice(
      start,
      start + EXAMPLE_ID_CHUNK_SIZE
    );
    for await (const example of client.listExamples({
      datasetId: referenceDatasetId,
      exampleIds: exampleIdsChunk,
      asOf: datasetVersion,
    })) {
      exampleMap[example.id] = example;
    }
  }

  const runMapByExampleId = {};
  for (const runs of experimentRuns) {
    for (const run of runs) {
      // Set lookup keeps this linear in the number of runs.
      if (
        run.reference_example_id == null ||
        !commonExampleIds.has(run.reference_example_id)
      ) {
        continue;
      }
      runMapByExampleId[run.reference_example_id] ??= [];
      runMapByExampleId[run.reference_example_id].push(run);
    }
  }

  const caller = new AsyncCaller({ maxConcurrency: options.maxConcurrency });

  // Runs one evaluator on the runs of one example and records every returned
  // score as feedback on the corresponding run.
  async function evaluateAndSubmitFeedback(runs, example, evaluator) {
    const expectedRunIds = new Set(runs.map((r) => r.id));
    const result = await evaluator(
      options.randomizeOrder ? shuffle(runs) : runs,
      example
    );

    for (const [runId, score] of Object.entries(result.scores)) {
      // Reject scores for run ids that were not handed to the evaluator.
      if (!expectedRunIds.has(runId)) {
        throw new Error(`Returning an invalid run id ${runId} from evaluator.`);
      }
      await client.createFeedback(runId, result.key, {
        score,
        sourceRunId: result.source_run_id,
        comparativeExperimentId: comparativeExperiment.id,
      });
    }
    return result;
  }

  // Wrap each evaluator with `traceable` under the "evaluators" project; the
  // current run tree's id is used as `source_run_id` unless the evaluator
  // supplies its own.
  const tracedEvaluators = options.evaluators.map((evaluator) =>
    traceable(
      async (runs, example) => {
        const evaluatorRun = getCurrentRunTree();
        const result = await evaluator(runs, example);

        // sanitise the payload before sending to LangSmith
        evaluatorRun.inputs = { runs: runs, example: example };
        evaluatorRun.outputs = result;

        return {
          ...result,
          source_run_id: result.source_run_id ?? evaluatorRun.id,
        };
      },
      {
        project_name: "evaluators",
        name: evaluator.name || "evaluator",
      }
    )
  );

  const promises = Object.entries(runMapByExampleId).flatMap(
    ([exampleId, runs]) => {
      const example = exampleMap[exampleId];
      if (!example) throw new Error(`Example ${exampleId} not found.`);
      return tracedEvaluators.map((evaluator) =>
        caller.call(evaluateAndSubmitFeedback, runs, example, evaluator)
      );
    }
  );

  const results = await Promise.all(promises);
  return { experimentName, results };
}