// agsamantha/node_modules/langsmith/dist/traceable.js
// 482 lines, 20 KiB, JavaScript — snapshot from 2024-10-02 20:15:21 +00:00
import { AsyncLocalStorage } from "node:async_hooks";
import { RunTree, isRunTree, isRunnableConfigLike, } from "./run_trees.js";
import { isTracingEnabled } from "./env.js";
import { ROOT, AsyncLocalStorageProviderSingleton, } from "./singletons/traceable.js";
import { isKVMap, isReadableStream, isAsyncIterable, isIteratorLike, isThenable, isGenerator, isPromiseMethod, } from "./utils/asserts.js";
// Module-load side effect: hand a Node.js AsyncLocalStorage instance to the
// provider singleton. `traceable` below reads it back through
// `AsyncLocalStorageProviderSingleton.getInstance()` to propagate the current
// run tree across async calls.
AsyncLocalStorageProviderSingleton.initializeGlobalInstance(new AsyncLocalStorage());
/**
 * Normalizes the positional arguments of a traced call into the `inputs`
 * record stored on a run:
 * - no arguments, or a nullish first argument -> `{}`
 * - more than one argument                    -> `{ args: rawInputs }`
 * - a single key/value map                    -> that map, unchanged
 * - a single other value                      -> `{ input: value }`
 */
const handleRunInputs = (rawInputs) => {
    const head = rawInputs[0];
    if (head == null) {
        return {};
    }
    if (rawInputs.length === 1) {
        return isKVMap(head) ? head : { input: head };
    }
    return { args: rawInputs };
};
/**
 * Shapes a traced function's result for storage as run outputs: key/value
 * maps are stored as-is, any other value is nested under `outputs`.
 */
const handleRunOutputs = (rawOutputs) => isKVMap(rawOutputs) ? rawOutputs : { outputs: rawOutputs };
/**
 * Prepares `runTree` to trace a call made with `inputs`.
 *
 * Returns `undefined` when tracing is disabled for this run tree. Otherwise it
 * stores the normalized inputs on the run tree, merges the (optional)
 * `getInvocationParams(...inputs)` result into `runTree.extra.metadata` —
 * metadata keys already present win over invocation params — and returns the
 * same run tree instance.
 */
const getTracingRunTree = (runTree, inputs, getInvocationParams) => {
    const tracingOn = isTracingEnabled(runTree.tracingEnabled);
    if (!tracingOn) {
        return undefined;
    }
    runTree.inputs = handleRunInputs(inputs);
    const params = getInvocationParams?.(...inputs);
    if (params == null) {
        return runTree;
    }
    if (runTree.extra == null) {
        runTree.extra = {};
    }
    const extra = runTree.extra;
    extra.metadata = { ...params, ...extra.metadata };
    return runTree;
};
// idea: store the state of the promise outside
// but only when the promise is "consumed"
/**
 * Wraps a thenable in a Proxy that remembers how it settled, so the outcome
 * can later be serialized into a trace through `toJSON()`.
 *
 * Settlement is only recorded once a consumer attaches handlers through the
 * proxy's `then` / `catch`; before that `toJSON()` returns `undefined`. A
 * fulfilled value serializes as itself, a rejection as `{ error }`. Every
 * other property access is forwarded to the wrapped object.
 *
 * Missing handlers follow Promise semantics: a non-function `onFulfilled`
 * passes the value through and a non-function `onRejected` re-throws the
 * error, rather than attempting to call `undefined`.
 *
 * @param arg The thenable to wrap.
 * @returns A proxy around `arg` exposing an extra `toJSON()` method.
 */
const getSerializablePromise = (arg) => {
    const proxyState = { current: undefined };
    const recordRejection = (error, onRejected) => {
        proxyState.current = ["reject", error];
        if (typeof onRejected === "function") {
            return onRejected(error);
        }
        throw error;
    };
    const promiseProxy = new Proxy(arg, {
        get(target, prop, receiver) {
            if (prop === "then") {
                const boundThen = arg[prop].bind(arg);
                return (onFulfilled, onRejected) => boundThen((value) => {
                    proxyState.current = ["resolve", value];
                    return typeof onFulfilled === "function"
                        ? onFulfilled(value)
                        : value;
                }, (error) => recordRejection(error, onRejected));
            }
            // Only intercept `catch` when the target actually has one; plain
            // thenables may not, and reading the property must not throw.
            if (prop === "catch" && typeof arg.catch === "function") {
                const boundCatch = arg.catch.bind(arg);
                return (onRejected) => boundCatch((error) => recordRejection(error, onRejected));
            }
            if (prop === "toJSON") {
                return () => {
                    if (!proxyState.current)
                        return undefined;
                    const [type, value] = proxyState.current;
                    return type === "resolve" ? value : { error: value };
                };
            }
            return Reflect.get(target, prop, receiver);
        },
    });
    return promiseProxy;
};
/**
 * Whether a recorded iterator result carries a chunk worth tracing: every
 * non-final value is kept (including falsy ones such as `0`, `""`, `false`),
 * and a final `done` result is kept only when it carries a return value.
 * Records that are not iterator results — `undefined` for an unsettled
 * promise, `{ error }` for a rejected one — are skipped.
 */
const isTracedIteratorChunk = (result) => result != null &&
    typeof result === "object" &&
    "value" in result &&
    (!result.done || result.value !== undefined);
/**
 * Pipes a ReadableStream through a pass-through TransformStream that records
 * every chunk; the returned stream's `toJSON()` yields the chunks seen so far.
 */
const wrapReadableStreamArg = (stream) => {
    const recorded = [];
    const tapped = stream.pipeThrough(new TransformStream({
        transform: (chunk, controller) => {
            recorded.push(chunk);
            controller.enqueue(chunk);
        },
    }));
    return Object.assign(tapped, { toJSON: () => recorded });
};
/**
 * Proxies an async iterable so every `next` / `return` / `throw` result of
 * its iterators is recorded (via `getSerializablePromise`). `toJSON()` yields
 * the values of the settled results. Protocol methods the source iterator
 * does not implement stay absent, so consumers fall back correctly.
 */
const wrapAsyncIterableArg = (iterable) => {
    const recorded = [];
    return new Proxy(iterable, {
        get(target, prop, receiver) {
            if (prop === Symbol.asyncIterator) {
                return () => {
                    const iterator = iterable[Symbol.asyncIterator]();
                    return new Proxy(iterator, {
                        get(iterTarget, iterProp, iterReceiver) {
                            const isProtocolMethod = iterProp === "next" ||
                                iterProp === "return" ||
                                iterProp === "throw";
                            // Forward to the method actually requested (the
                            // previous code always called `next`, so `return`
                            // advanced the source instead of closing it).
                            const method = isProtocolMethod
                                ? iterator[iterProp]
                                : undefined;
                            if (typeof method !== "function") {
                                return Reflect.get(iterTarget, iterProp, iterReceiver);
                            }
                            return (...args) => {
                                const pending = getSerializablePromise(method.apply(iterator, args));
                                recorded.push(pending);
                                return pending;
                            };
                        },
                    });
                };
            }
            if (prop === "toJSON") {
                return () => recorded
                    .map((pending) => pending.toJSON())
                    .filter(isTracedIteratorChunk)
                    .map((result) => result.value);
            }
            return Reflect.get(target, prop, receiver);
        },
    });
};
/**
 * Proxies a synchronous iterator so every `next` / `return` / `throw` result
 * is recorded; `toJSON()` yields the recorded values. Protocol methods the
 * source does not implement stay absent (exposing a stub `return` would make
 * `for...of` + `break` throw on its `undefined` result).
 */
const wrapIteratorArg = (iterator) => {
    const recorded = [];
    return new Proxy(iterator, {
        get(target, prop, receiver) {
            const isProtocolMethod = prop === "next" || prop === "return" || prop === "throw";
            if (isProtocolMethod && typeof iterator[prop] === "function") {
                const bound = iterator[prop].bind(iterator);
                return (...args) => {
                    const result = bound(...args);
                    if (result != null)
                        recorded.push(result);
                    return result;
                };
            }
            if (prop === "toJSON") {
                return () => recorded
                    .filter(isTracedIteratorChunk)
                    .map((result) => result.value);
            }
            return Reflect.get(target, prop, receiver);
        },
    });
};
/**
 * Wraps a traced-function argument so its lazily-consumed contents (stream
 * chunks, iterator values, promise outcomes) can later be serialized through
 * `toJSON()`. Any other value is returned unchanged.
 *
 * @param arg A single positional argument of the traced call.
 * @returns The wrapped argument, or `arg` itself.
 */
const convertSerializableArg = (arg) => {
    if (isReadableStream(arg)) {
        return wrapReadableStreamArg(arg);
    }
    if (isAsyncIterable(arg)) {
        return wrapAsyncIterableArg(arg);
    }
    if (!Array.isArray(arg) && isIteratorLike(arg)) {
        return wrapIteratorArg(arg);
    }
    if (isThenable(arg)) {
        return getSerializablePromise(arg);
    }
    return arg;
};
/**
 * Higher-order function that takes function as input and returns a
 * "TraceableFunction" - a wrapped version of the input that
 * automatically handles tracing. If the returned traceable function calls any
 * traceable functions, those are automatically traced as well.
 *
 * The returned TraceableFunction can accept a run tree or run tree config as
 * its first argument. If omitted, it will default to the caller's run tree,
 * or will be treated as a root run.
 *
 * Return values are handled as follows:
 * - async iterables and ReadableStreams (returned directly or through a
 *   promise) are wrapped so the run ends when the consumer finishes, cancels,
 *   or the source errors;
 * - when `config.__finalTracedIteratorKey` is set and the returned object
 *   holds an async iterable under that key, a shallow copy of the object is
 *   returned with that property wrapped;
 * - when `wrappedFunc` is a generator function, its iterator is drained once
 *   the result settles and the traced promise resolves to a generator that
 *   replays the drained values;
 * - anything else ends the run once the (possibly synchronous) result
 *   resolves. A synchronous throw is converted into a rejected promise.
 *
 * @param wrappedFunc Targeted function to be traced
 * @param config Additional metadata such as name, tags or providing
 * a custom LangSmith client instance
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function traceable(wrappedFunc, config) {
    const { aggregator, argsConfigPath, __finalTracedIteratorKey, ...runTreeConfig } = config ?? {};
    const traceableFunc = (...args) => {
        let ensuredConfig;
        try {
            // `argsConfigPath = [index, path]` locates a runtime config inside
            // the call arguments: either the whole argument at `index` (removed
            // from `args`), or its `path` property (stripped from a shallow copy
            // of that argument).
            let runtimeConfig;
            if (argsConfigPath) {
                const [index, path] = argsConfigPath;
                if (index === args.length - 1 && !path) {
                    runtimeConfig = args.pop();
                }
                else if (index < args.length &&
                    typeof args[index] === "object" &&
                    args[index] !== null) {
                    if (path) {
                        const { [path]: extracted, ...rest } = args[index];
                        runtimeConfig = extracted;
                        args[index] = rest;
                    }
                    else {
                        runtimeConfig = args[index];
                        args.splice(index, 1);
                    }
                }
            }
            // Runtime config overrides static config; tags are de-duplicated
            // and metadata objects are shallow-merged.
            ensuredConfig = {
                name: wrappedFunc.name || "<lambda>",
                ...runTreeConfig,
                ...runtimeConfig,
                tags: [
                    ...new Set([
                        ...(runTreeConfig?.tags ?? []),
                        ...(runtimeConfig?.tags ?? []),
                    ]),
                ],
                metadata: {
                    ...runTreeConfig?.metadata,
                    ...runtimeConfig?.metadata,
                },
            };
        }
        catch (err) {
            console.warn(`Failed to extract runtime config from args for ${runTreeConfig?.name ?? wrappedFunc.name}`, err);
            ensuredConfig = {
                name: wrappedFunc.name || "<lambda>",
                ...runTreeConfig,
            };
        }
        const asyncLocalStorage = AsyncLocalStorageProviderSingleton.getInstance();
        // TODO: deal with possible nested promises and async iterables
        // Wrap streams / iterators / promises among the arguments so their
        // consumed contents can be serialized into the run inputs.
        const processedArgs = args;
        for (let i = 0; i < processedArgs.length; i++) {
            processedArgs[i] = convertSerializableArg(processedArgs[i]);
        }
        // Decide which run tree (if any) traces this call, and which
        // arguments are forwarded to `wrappedFunc`.
        const [currentRunTree, rawInputs] = (() => {
            const [firstArg, ...restArgs] = processedArgs;
            // used for handoff between LangChain.JS and traceable functions
            if (isRunnableConfigLike(firstArg)) {
                return [
                    getTracingRunTree(RunTree.fromRunnableConfig(firstArg, ensuredConfig), restArgs, config?.getInvocationParams),
                    restArgs,
                ];
            }
            // deprecated: legacy CallbackManagerRunTree used in runOnDataset
            // override ALS and do not pass-through the run tree
            if (isRunTree(firstArg) &&
                "callbackManager" in firstArg &&
                firstArg.callbackManager != null) {
                return [firstArg, restArgs];
            }
            // when ALS is unreliable, users can manually
            // pass in the run tree; the new run tree is forwarded to
            // `wrappedFunc` as its first argument.
            if (firstArg === ROOT || isRunTree(firstArg)) {
                const currentRunTree = getTracingRunTree(firstArg === ROOT
                    ? new RunTree(ensuredConfig)
                    : firstArg.createChild(ensuredConfig), restArgs, config?.getInvocationParams);
                return [currentRunTree, [currentRunTree, ...restArgs]];
            }
            // Node.JS uses AsyncLocalStorage (ALS) and AsyncResource
            // to allow storing context
            const prevRunFromStore = asyncLocalStorage.getStore();
            if (prevRunFromStore) {
                return [
                    getTracingRunTree(prevRunFromStore.createChild(ensuredConfig), processedArgs, config?.getInvocationParams),
                    processedArgs,
                ];
            }
            const currentRunTree = getTracingRunTree(new RunTree(ensuredConfig), processedArgs, config?.getInvocationParams);
            return [currentRunTree, processedArgs];
        })();
        return asyncLocalStorage.run(currentRunTree, () => {
            const postRunPromise = currentRunTree?.postRun();
            // Runs the optional `aggregator` over collected chunks; if it
            // throws, the error is logged and the raw chunks are used instead.
            async function handleChunks(chunks) {
                if (aggregator !== undefined) {
                    try {
                        return await aggregator(chunks);
                    }
                    catch (e) {
                        console.error(`[ERROR]: LangSmith aggregation failed: `, e);
                    }
                }
                return chunks;
            }
            // Re-emits `stream` through a new ReadableStream, collecting chunks
            // so the run can be ended with them once the source is exhausted.
            function tapReadableStreamForTracing(stream, snapshot) {
                const reader = stream.getReader();
                let finished = false;
                const chunks = [];
                const tappedStream = new ReadableStream({
                    async start(controller) {
                        // eslint-disable-next-line no-constant-condition
                        while (true) {
                            let result;
                            try {
                                result = await (snapshot
                                    ? snapshot(() => reader.read())
                                    : reader.read());
                            }
                            catch (e) {
                                // The source errored: end the run with the error
                                // (as wrapAsyncIteratorForTracing does) so it is
                                // not left open, then rethrow so `tappedStream`
                                // errors with the original reason.
                                await currentRunTree?.end(undefined, String(e));
                                await handleEnd();
                                throw e;
                            }
                            if (result.done) {
                                finished = true;
                                await currentRunTree?.end(handleRunOutputs(await handleChunks(chunks)));
                                await handleEnd();
                                controller.close();
                                break;
                            }
                            chunks.push(result.value);
                            controller.enqueue(result.value);
                        }
                    },
                    async cancel(reason) {
                        if (!finished)
                            await currentRunTree?.end(undefined, "Cancelled");
                        await currentRunTree?.end(handleRunOutputs(await handleChunks(chunks)));
                        await handleEnd();
                        return reader.cancel(reason);
                    },
                });
                return tappedStream;
            }
            // Yields every value of `iterator` while collecting it; the run is
            // ended with the error on failure, marked "Cancelled" if the
            // consumer stops early, and finalized with the chunks either way.
            async function* wrapAsyncIteratorForTracing(iterator, snapshot) {
                let finished = false;
                const chunks = [];
                try {
                    while (true) {
                        const { value, done } = await (snapshot
                            ? snapshot(() => iterator.next())
                            : iterator.next());
                        if (done) {
                            finished = true;
                            break;
                        }
                        chunks.push(value);
                        yield value;
                    }
                }
                catch (e) {
                    await currentRunTree?.end(undefined, String(e));
                    throw e;
                }
                finally {
                    if (!finished)
                        await currentRunTree?.end(undefined, "Cancelled");
                    await currentRunTree?.end(handleRunOutputs(await handleChunks(chunks)));
                    await handleEnd();
                }
            }
            // ReadableStreams get a tapped copy; other async iterables have
            // their Symbol.asyncIterator replaced (in place) by a traced one.
            function wrapAsyncGeneratorForTracing(iterable, snapshot) {
                if (isReadableStream(iterable)) {
                    return tapReadableStreamForTracing(iterable, snapshot);
                }
                const iterator = iterable[Symbol.asyncIterator]();
                const wrappedIterator = wrapAsyncIteratorForTracing(iterator, snapshot);
                iterable[Symbol.asyncIterator] = () => wrappedIterator;
                return iterable;
            }
            // Invokes `config.on_end` (not awaited), waits for the initial
            // postRun, then patches the run with its final state.
            async function handleEnd() {
                const onEnd = config?.on_end;
                if (onEnd) {
                    if (!currentRunTree) {
                        console.warn("Can not call 'on_end' if currentRunTree is undefined");
                    }
                    else {
                        onEnd(currentRunTree);
                    }
                }
                await postRunPromise;
                await currentRunTree?.patchRun();
            }
            // Drains a sync iterator, keeping every result including the final
            // `done` one (whose value is the generator's return value).
            function gatherAll(iterator) {
                const chunks = [];
                // eslint-disable-next-line no-constant-condition
                while (true) {
                    const next = iterator.next();
                    chunks.push(next);
                    if (next.done)
                        break;
                }
                return chunks;
            }
            let returnValue;
            try {
                returnValue = wrappedFunc(...rawInputs);
            }
            catch (err) {
                returnValue = Promise.reject(err);
            }
            if (isAsyncIterable(returnValue)) {
                const snapshot = AsyncLocalStorage.snapshot();
                return wrapAsyncGeneratorForTracing(returnValue, snapshot);
            }
            if (!Array.isArray(returnValue) &&
                typeof returnValue === "object" &&
                returnValue != null &&
                __finalTracedIteratorKey !== undefined &&
                isAsyncIterable(returnValue[__finalTracedIteratorKey])) {
                const snapshot = AsyncLocalStorage.snapshot();
                return {
                    ...returnValue,
                    [__finalTracedIteratorKey]: wrapAsyncGeneratorForTracing(returnValue[__finalTracedIteratorKey], snapshot),
                };
            }
            const tracedPromise = new Promise((resolve, reject) => {
                Promise.resolve(returnValue)
                    .then(async (rawOutput) => {
                    if (isAsyncIterable(rawOutput)) {
                        const snapshot = AsyncLocalStorage.snapshot();
                        return resolve(wrapAsyncGeneratorForTracing(rawOutput, snapshot));
                    }
                    if (!Array.isArray(rawOutput) &&
                        typeof rawOutput === "object" &&
                        rawOutput != null &&
                        __finalTracedIteratorKey !== undefined &&
                        isAsyncIterable(rawOutput[__finalTracedIteratorKey])) {
                        const snapshot = AsyncLocalStorage.snapshot();
                        return resolve({
                            ...rawOutput,
                            [__finalTracedIteratorKey]: wrapAsyncGeneratorForTracing(rawOutput[__finalTracedIteratorKey], snapshot),
                        });
                    }
                    if (isGenerator(wrappedFunc) && isIteratorLike(rawOutput)) {
                        const chunks = gatherAll(rawOutput);
                        try {
                            await currentRunTree?.end(handleRunOutputs(await handleChunks(chunks.reduce((memo, { value, done }) => {
                                if (!done || typeof value !== "undefined") {
                                    memo.push(value);
                                }
                                return memo;
                            }, []))));
                            await handleEnd();
                        }
                        catch (e) {
                            console.error("Error occurred during handleEnd:", e);
                        }
                        // Replay the drained values, ending with the original
                        // return value.
                        return (function* () {
                            for (const ret of chunks) {
                                if (ret.done)
                                    return ret.value;
                                yield ret.value;
                            }
                        })();
                    }
                    try {
                        await currentRunTree?.end(handleRunOutputs(rawOutput));
                        await handleEnd();
                    }
                    finally {
                        // eslint-disable-next-line no-unsafe-finally
                        return rawOutput;
                    }
                }, async (error) => {
                    await currentRunTree?.end(undefined, String(error));
                    await handleEnd();
                    throw error;
                })
                    .then(resolve, reject);
            });
            if (typeof returnValue !== "object" || returnValue === null) {
                return tracedPromise;
            }
            // Object results are returned behind a proxy whose promise methods
            // (as decided by `isPromiseMethod`) are served by `tracedPromise`,
            // so awaiting the result waits for tracing to finish; all other
            // property reads go to the original value.
            return new Proxy(returnValue, {
                get(target, prop, receiver) {
                    if (isPromiseMethod(prop)) {
                        return tracedPromise[prop].bind(tracedPromise);
                    }
                    return Reflect.get(target, prop, receiver);
                },
            });
        });
    };
    Object.defineProperty(traceableFunc, "langsmith:traceable", {
        value: runTreeConfig,
    });
    return traceableFunc;
}
export { getCurrentRunTree, isTraceableFunction, withRunTree, ROOT, } from "./singletons/traceable.js";