import { BaseChatModel, } from "@langchain/core/language_models/chat_models";
import { AIMessage, AIMessageChunk, ChatMessage, } from "@langchain/core/messages";
import { ChatGenerationChunk, } from "@langchain/core/outputs";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { convertEventStreamToIterableReadableDataStream } from "../utils/event_source_parse.js";

/**
 * Maps each supported model name to the last path segment of its Wenxin
 * chat endpoint.
 */
const WENXIN_MODEL_ENDPOINTS = Object.freeze({
    "ERNIE-Bot": "completions",
    "ERNIE-Bot-turbo": "eb-instant",
    "ERNIE-Bot-4": "completions_pro",
    "ERNIE-Speed-8K": "ernie_speed",
    "ERNIE-Speed-128K": "ernie-speed-128k",
    "ERNIE-4.0-8K": "completions_pro",
    "ERNIE-4.0-8K-Preview": "ernie-4.0-8k-preview",
    "ERNIE-3.5-8K": "completions",
    "ERNIE-3.5-8K-Preview": "ernie-3.5-8k-preview",
    "ERNIE-Lite-8K": "eb-instant",
    "ERNIE-Tiny-8K": "ernie-tiny-8k",
    "ERNIE-Character-8K": "ernie-char-8k",
    "ERNIE Speed-AppBuilder": "ai_apaas",
});

/** Byte value of "\n", used to split the raw event stream into lines. */
const NEWLINE_BYTE = 10;

/**
 * Function that extracts the custom role of a generic chat message.
 * A warning is logged for roles other than "assistant" and "user", but the
 * role is returned unchanged either way.
 * @param message Chat message from which to extract the custom role.
 * @returns The custom role of the chat message.
 */
function extractGenericMessageCustomRole(message) {
    if (message.role !== "assistant" && message.role !== "user") {
        console.warn(`Unknown message role: ${message.role}`);
    }
    return message.role;
}

/**
 * Function that converts a base message to a Wenxin message role.
 * @param message Base message to convert.
 * @returns The Wenxin message role.
 * @throws If the message is a system or function message, an invalid generic
 * message, or of an unknown type.
 */
function messageToWenxinRole(message) {
    const type = message._getType();
    switch (type) {
        case "ai":
            return "assistant";
        case "human":
            return "user";
        case "system":
            // System messages are moved into the request params before mapping.
            throw new Error("System messages should not be here");
        case "function":
            throw new Error("Function messages not supported");
        case "generic": {
            if (!ChatMessage.isInstance(message)) {
                throw new Error("Invalid generic chat message");
            }
            return extractGenericMessageCustomRole(message);
        }
        default:
            throw new Error(`Unknown message type: ${type}`);
    }
}

/**
 * Separates the first system message from the rest of the conversation.
 * Wenxin requires the system message to be put in the params, not in the
 * messages array. Only the first system message is removed; any further one
 * stays in `rest` and is rejected later by `messageToWenxinRole`.
 * @param messages Messages passed to the model.
 * @returns The text of the first system message (`undefined` when there is
 * none) and the messages without it.
 */
function splitSystemMessage(messages) {
    const systemMessage = messages.find((message) => message._getType() === "system");
    if (!systemMessage) {
        return { system: undefined, rest: messages };
    }
    return {
        system: systemMessage.text,
        rest: messages.filter((message) => message !== systemMessage),
    };
}

/**
 * @deprecated Install and import from @langchain/baidu-qianfan instead.
 * Wrapper around Baidu ERNIE large language models that use the Chat endpoint.
 *
 * To use you should have the `BAIDU_API_KEY` and `BAIDU_SECRET_KEY`
 * environment variables set.
 *
 * @augments BaseChatModel
 * @augments BaiduERNIEInput
 * @example
 * ```typescript
 * const ernieTurbo = new ChatBaiduWenxin({
 *   apiKey: "YOUR-API-KEY",
 *   baiduSecretKey: "YOUR-SECRET-KEY",
 * });
 *
 * const ernie = new ChatBaiduWenxin({
 *   model: "ERNIE-Bot",
 *   temperature: 1,
 *   apiKey: "YOUR-API-KEY",
 *   baiduSecretKey: "YOUR-SECRET-KEY",
 * });
 *
 * const messages = [new HumanMessage("Hello")];
 *
 * let res = await ernieTurbo.call(messages);
 *
 * res = await ernie.call(messages);
 * ```
 */
export class ChatBaiduWenxin extends BaseChatModel {
    static lc_name() {
        return "ChatBaiduWenxin";
    }

    get callKeys() {
        return ["stop", "signal", "options"];
    }

    get lc_secrets() {
        return {
            baiduApiKey: "BAIDU_API_KEY",
            apiKey: "BAIDU_API_KEY",
            baiduSecretKey: "BAIDU_SECRET_KEY",
        };
    }

    get lc_aliases() {
        return undefined;
    }

    /**
     * @param fields Optional configuration. Credentials fall back to the
     * `BAIDU_API_KEY` and `BAIDU_SECRET_KEY` environment variables.
     * @throws If the API key or secret key cannot be found, or if the model
     * name is not one of the supported models.
     */
    constructor(fields) {
        super(fields ?? {});
        // Define every instance field, in a fixed order, as an enumerable,
        // configurable and writable own property with its default value.
        const fieldDefaults = {
            lc_serializable: true,
            baiduApiKey: undefined,
            apiKey: undefined,
            baiduSecretKey: undefined,
            accessToken: undefined,
            streaming: false,
            prefixMessages: undefined,
            userId: undefined,
            modelName: "ERNIE-Bot-turbo",
            model: "ERNIE-Bot-turbo",
            apiUrl: undefined,
            temperature: undefined,
            topP: undefined,
            penaltyScore: undefined,
        };
        for (const [name, value] of Object.entries(fieldDefaults)) {
            Object.defineProperty(this, name, {
                enumerable: true,
                configurable: true,
                writable: true,
                value,
            });
        }

        this.baiduApiKey =
            fields?.apiKey ??
                fields?.baiduApiKey ??
                getEnvironmentVariable("BAIDU_API_KEY");
        if (!this.baiduApiKey) {
            throw new Error("Baidu API key not found");
        }
        this.apiKey = this.baiduApiKey;

        this.baiduSecretKey =
            fields?.baiduSecretKey ?? getEnvironmentVariable("BAIDU_SECRET_KEY");
        if (!this.baiduSecretKey) {
            throw new Error("Baidu Secret key not found");
        }

        this.streaming = fields?.streaming ?? this.streaming;
        this.prefixMessages = fields?.prefixMessages ?? this.prefixMessages;
        this.userId = fields?.userId ?? this.userId;
        this.temperature = fields?.temperature ?? this.temperature;
        this.topP = fields?.topP ?? this.topP;
        this.penaltyScore = fields?.penaltyScore ?? this.penaltyScore;
        this.modelName = fields?.model ?? fields?.modelName ?? this.model;
        this.model = this.modelName;

        // Own-property check: the `in` operator would also accept inherited
        // keys such as "toString" or "constructor" as valid model names.
        if (!Object.prototype.hasOwnProperty.call(WENXIN_MODEL_ENDPOINTS, this.model)) {
            throw new Error(`Invalid model name: ${this.model}`);
        }
        this.apiUrl = `https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/${WENXIN_MODEL_ENDPOINTS[this.model]}`;
    }

    /**
     * Method that retrieves the access token for making requests to the Baidu
     * API.
     * @param options Optional parsed call options; only `signal` is read.
     * @returns The access token for making requests to the Baidu API.
     * @throws If the token endpoint answers with a non-OK status. The thrown
     * error carries the fetch response on its `response` property.
     */
    async getAccessToken(options) {
        // URLSearchParams percent-encodes the credentials so that reserved
        // characters cannot corrupt the query string.
        const query = new URLSearchParams({
            grant_type: "client_credentials",
            client_id: this.apiKey,
            client_secret: this.baiduSecretKey,
        });
        const url = `https://aip.baidubce.com/oauth/2.0/token?${query}`;
        const response = await fetch(url, {
            method: "POST",
            headers: {
                "Content-Type": "application/json",
                Accept: "application/json",
            },
            signal: options?.signal,
        });
        if (!response.ok) {
            const text = await response.text();
            const error = new Error(`Baidu get access token failed with status code ${response.status}, response: ${text}`);
            // eslint-disable-next-line @typescript-eslint/no-explicit-any
            error.response = response;
            throw error;
        }
        const json = await response.json();
        return json.access_token;
    }

    /**
     * Get the parameters used to invoke the model
     */
    invocationParams() {
        return {
            stream: this.streaming,
            user_id: this.userId,
            temperature: this.temperature,
            top_p: this.topP,
            penalty_score: this.penaltyScore,
        };
    }

    /**
     * Get the identifying parameters for the model
     */
    identifyingParams() {
        return {
            model_name: this.model,
            ...this.invocationParams(),
        };
    }

    /**
     * Converts messages into the `{ role, content }` shape sent to Wenxin.
     * @param messages Messages to convert; must not contain system messages.
     * @returns The converted messages.
     */
    _ensureMessages(messages) {
        return messages.map((message) => ({
            role: messageToWenxinRole(message),
            content: message.text,
        }));
    }

    /** @ignore */
    async _generate(messages, options, runManager) {
        const params = this.invocationParams();
        const { system, rest } = splitSystemMessage(messages);
        if (system !== undefined) {
            params.system = system;
        }
        const request = {
            ...params,
            messages: this._ensureMessages(rest),
        };

        let data;
        if (params.stream) {
            data = await new Promise((resolve, reject) => {
                let response;
                let settled = false;
                // Settles the promise at most once, whichever outcome comes first.
                const finish = (settle, value) => {
                    if (settled) {
                        return;
                    }
                    settled = true;
                    settle(value);
                };
                this.completionWithRetry(request, true, options?.signal, (event) => {
                    const message = JSON.parse(event.data);
                    if (message?.error_code) {
                        finish(reject, new Error(message?.error_msg));
                        return;
                    }
                    // on the first message set the response properties
                    if (!response) {
                        response = {
                            id: message.id,
                            object: message.object,
                            created: message.created,
                            result: message.result,
                            need_clear_history: message.need_clear_history,
                            usage: message.usage,
                        };
                    }
                    else {
                        response.result = (response.result ?? "") + (message.result ?? "");
                        response.created = message.created;
                        response.need_clear_history = message.need_clear_history;
                        response.usage = message.usage;
                    }
                    // TODO this should pass part.index to the callback
                    // when that's supported there
                    // eslint-disable-next-line no-void
                    void runManager?.handleLLMNewToken(message.result ?? "");
                    if (message.is_end) {
                        finish(resolve, response);
                    }
                })
                    .then(() => {
                        // The stream closed without an `is_end` event. Settle
                        // anyway so the caller is not left waiting forever.
                        if (response) {
                            finish(resolve, response);
                        }
                        else {
                            finish(reject, new Error("Baidu Wenxin stream ended without returning any data"));
                        }
                    })
                    .catch((error) => finish(reject, error));
            });
        }
        else {
            data = await this.completionWithRetry(request, false, options?.signal);
            if (data?.error_code) {
                throw new Error(data?.error_msg);
            }
        }

        const { completion_tokens: completionTokens, prompt_tokens: promptTokens, total_tokens: totalTokens, } = data.usage ?? {};
        // Only counts that are present and non-zero are reported.
        const tokenUsage = {};
        if (completionTokens) {
            tokenUsage.completionTokens = completionTokens;
        }
        if (promptTokens) {
            tokenUsage.promptTokens = promptTokens;
        }
        if (totalTokens) {
            tokenUsage.totalTokens = totalTokens;
        }

        const text = data.result ?? "";
        return {
            generations: [{ text, message: new AIMessage(text) }],
            llmOutput: { tokenUsage },
        };
    }

    /**
     * Sends a chat completion request through the retrying caller.
     * @param request JSON-serializable request body.
     * @param stream Whether the response should be read as an event stream.
     * @param signal Optional abort signal for the HTTP requests.
     * @param onmessage Callback invoked with a `MessageEvent` for every
     * `data:` line of the stream (streaming mode only).
     * @returns The parsed JSON response when `stream` is false, otherwise
     * `undefined` once the stream has been fully read.
     * @ignore
     */
    async completionWithRetry(request, stream, signal, onmessage) {
        // The first run will get the accessToken
        if (!this.accessToken) {
            this.accessToken = await this.getAccessToken({ signal });
        }

        // Decodes one raw line and forwards it if it is an SSE data line.
        const decoder = new TextDecoder("utf-8");
        const emitLine = (bytes) => {
            const line = decoder.decode(bytes);
            if (line.startsWith("data:")) {
                onmessage?.(new MessageEvent("message", {
                    data: line.slice("data:".length).trim(),
                }));
            }
        };

        const makeCompletionRequest = async () => {
            const url = `${this.apiUrl}?access_token=${this.accessToken}`;
            const response = await fetch(url, {
                method: "POST",
                headers: {
                    "Content-Type": "application/json",
                },
                body: JSON.stringify(request),
                signal,
            });
            if (!stream) {
                return response.json();
            }
            if (!response.body) {
                return undefined;
            }
            // response will not be a stream if an error occurred
            if (!response.headers.get("content-type")?.startsWith("text/event-stream")) {
                onmessage?.(new MessageEvent("message", {
                    data: await response.text(),
                }));
                return undefined;
            }

            const reader = response.body.getReader();
            // Accumulate raw bytes and decode only complete lines, so that
            // multi-byte (e.g. Chinese) characters split between chunks are
            // never decoded in halves.
            let pending = new Uint8Array(0);
            let chunk = await reader.read();
            while (!chunk.done) {
                const merged = new Uint8Array(pending.length + chunk.value.length);
                merged.set(pending);
                merged.set(chunk.value, pending.length);
                pending = merged;
                // In UTF-8 the byte 0x0A only ever encodes "\n" (continuation
                // bytes are >= 0x80), so a plain byte search is sufficient.
                let newlineIndex = pending.indexOf(NEWLINE_BYTE);
                while (newlineIndex !== -1) {
                    emitLine(pending.subarray(0, newlineIndex));
                    pending = pending.subarray(newlineIndex + 1);
                    newlineIndex = pending.indexOf(NEWLINE_BYTE);
                }
                chunk = await reader.read();
            }
            // Flush a final line that was not terminated by a newline.
            if (pending.length > 0) {
                emitLine(pending);
            }
            return undefined;
        };
        return this.caller.call(makeCompletionRequest);
    }

    /**
     * Builds the chat endpoint URL including the access token, fetching the
     * token first if it has not been retrieved yet.
     * @param signal Optional abort signal used for the token request.
     * @returns The full API URL.
     */
    async getFullApiUrl(signal) {
        if (!this.accessToken) {
            this.accessToken = await this.getAccessToken({ signal });
        }
        return `${this.apiUrl}?access_token=${this.accessToken}`;
    }

    /**
     * Opens a streaming request against the chat endpoint.
     * @param request JSON-serializable request body.
     * @param signal Optional abort signal for the HTTP requests.
     * @returns The value produced by
     * `convertEventStreamToIterableReadableDataStream` for the response body.
     * @throws If the response has no body.
     */
    async createWenxinStream(request, signal) {
        const url = await this.getFullApiUrl(signal);
        const response = await fetch(url, {
            method: "POST",
            headers: {
                Accept: "text/event-stream",
                "Content-Type": "application/json",
            },
            body: JSON.stringify(request),
            signal,
        });
        if (!response.body) {
            throw new Error("Could not begin Wenxin stream. Please check the given URL and try again.");
        }
        return convertEventStreamToIterableReadableDataStream(response.body);
    }

    /**
     * Parses a stream chunk as JSON.
     * @param json Raw chunk text.
     * @returns The parsed value, or `undefined` (after logging a warning)
     * when the chunk is not valid JSON.
     */
    _deserialize(json) {
        try {
            return JSON.parse(json);
        }
        catch (e) {
            console.warn(`Received a non-JSON parseable chunk: ${json}`);
            return undefined;
        }
    }

    /**
     * Streams the model response as generation chunks.
     * @param messages Messages to send.
     * @param options Parsed call options; only `signal` is read.
     * @param runManager Optional callback manager notified of each new token.
     * @throws If a chunk carries an `error_code`.
     */
    async *_streamResponseChunks(messages, options, runManager) {
        const parameters = {
            ...this.invocationParams(),
            stream: true,
        };
        const { system, rest } = splitSystemMessage(messages);
        if (system !== undefined) {
            parameters.system = system;
        }
        const messagesMapped = this._ensureMessages(rest);
        const stream = await this.caller.call(async () => this.createWenxinStream({
            ...parameters,
            messages: messagesMapped,
        }, options?.signal));
        for await (const chunk of stream) {
            const deserializedChunk = this._deserialize(chunk);
            if (!deserializedChunk) {
                // Not valid JSON; _deserialize has already logged a warning.
                continue;
            }
            if (deserializedChunk.error_code) {
                // Same error handling as the non-streaming path in _generate.
                throw new Error(deserializedChunk.error_msg);
            }
            // `usage` must be read from the parsed payload: `chunk` itself is
            // the raw text handed to _deserialize and has no such property.
            const { result, is_end, id, usage } = deserializedChunk;
            const text = result ?? "";
            yield new ChatGenerationChunk({
                text,
                message: new AIMessageChunk({ content: text }),
                generationInfo: is_end
                    ? {
                        is_end,
                        request_id: id,
                        usage,
                    }
                    : undefined,
            });
            await runManager?.handleLLMNewToken(text);
        }
    }

    _llmType() {
        return "baiduwenxin";
    }

    /** @ignore */
    _combineLLMOutput() {
        return [];
    }
}