import { BaseChatModel } from "@langchain/core/language_models/chat_models";
import { AIMessage, ChatMessage, HumanMessage, SystemMessage, HumanMessageChunk, AIMessageChunk, SystemMessageChunk, ChatMessageChunk, } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { convertEventStreamToIterableReadableDataStream } from "../utils/event_source_parse.js";
/**
 * Maps a LangChain message to the role string expected by the Friendli API.
 */
function messageToFriendliRole(message) {
    const type = message._getType();
    switch (type) {
        case "ai":
            return "assistant";
        case "human":
            return "user";
        case "system":
            return "system";
        case "function":
            throw new Error("Function messages not supported");
        case "generic": {
            if (!ChatMessage.isInstance(message)) {
                throw new Error("Invalid generic chat message");
            }
            if (["system", "assistant", "user"].includes(message.role)) {
                return message.role;
            }
            throw new Error(`Unknown message type: ${type}`);
        }
        default:
            throw new Error(`Unknown message type: ${type}`);
    }
}
/**
 * Converts a Friendli response message into the matching LangChain message class.
 */
function friendliResponseToChatMessage(message) {
    switch (message.role) {
        case "user":
            return new HumanMessage(message.content ?? "");
        case "assistant":
            return new AIMessage(message.content ?? "");
        case "system":
            return new SystemMessage(message.content ?? "");
        default:
            return new ChatMessage(message.content ?? "", message.role ?? "unknown");
    }
}
/**
 * Converts a streaming delta into the matching LangChain message chunk.
 */
function _convertDeltaToMessageChunk(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
delta) {
    const role = delta.role ?? "assistant";
    const content = delta.content ?? "";
    let additional_kwargs;
    if (delta.function_call) {
        additional_kwargs = {
            function_call: delta.function_call,
        };
    }
    else {
        additional_kwargs = {};
    }
    if (role === "user") {
        return new HumanMessageChunk({ content });
    }
    else if (role === "assistant") {
        return new AIMessageChunk({ content, additional_kwargs });
    }
    else if (role === "system") {
        return new SystemMessageChunk({ content });
    }
    else {
        return new ChatMessageChunk({ content, role });
    }
}
/**
 * The ChatFriendli class is used to interact with Friendli inference Endpoint
 * models. It requires a Friendli Token and, optionally, a Friendli Team; both
 * are loaded from the FRIENDLI_TOKEN and FRIENDLI_TEAM environment variables
 * when not passed explicitly.
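 *
 * Below is a usage sketch (illustrative, not from the original docs): the
 * model name shown is simply this class's default, and the token is assumed
 * to be provided via the FRIENDLI_TOKEN environment variable.
 * @example
 * ```typescript
 * const model = new ChatFriendli({
 *   model: "meta-llama-3-8b-instruct", // the default; assumes this model is deployed
 *   friendliToken: process.env.FRIENDLI_TOKEN,
 * });
 * const result = await model.invoke([new HumanMessage("Hello, world!")]);
 * console.log(result.content);
 * ```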
 */
export class ChatFriendli extends BaseChatModel {
    static lc_name() {
        return "Friendli";
    }
    get lc_secrets() {
        return {
            friendliToken: "FRIENDLI_TOKEN",
            friendliTeam: "FRIENDLI_TEAM",
        };
    }
    constructor(fields) {
        super(fields);
        Object.defineProperty(this, "lc_serializable", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: true
        });
        Object.defineProperty(this, "model", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "meta-llama-3-8b-instruct"
        });
        Object.defineProperty(this, "baseUrl", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: "https://inference.friendli.ai"
        });
        Object.defineProperty(this, "friendliToken", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "friendliTeam", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "frequencyPenalty", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "maxTokens", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "stop", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "temperature", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "topP", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        Object.defineProperty(this, "modelKwargs", {
            enumerable: true,
            configurable: true,
            writable: true,
            value: void 0
        });
        this.model = fields?.model ?? this.model;
        this.baseUrl = fields?.baseUrl ?? this.baseUrl;
        this.friendliToken =
            fields?.friendliToken ?? getEnvironmentVariable("FRIENDLI_TOKEN");
        this.friendliTeam =
            fields?.friendliTeam ?? getEnvironmentVariable("FRIENDLI_TEAM");
        this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty;
        this.maxTokens = fields?.maxTokens ?? this.maxTokens;
        this.stop = fields?.stop ?? this.stop;
        this.temperature = fields?.temperature ?? this.temperature;
        this.topP = fields?.topP ?? this.topP;
        this.modelKwargs = fields?.modelKwargs ?? {};
        if (!this.friendliToken) {
            throw new Error("Missing Friendli Token");
        }
    }
    _llmType() {
        return "friendli";
    }
    constructHeaders(stream) {
        return {
            "Content-Type": "application/json",
            Accept: stream ? "text/event-stream" : "application/json",
            Authorization: `Bearer ${this.friendliToken}`,
            "X-Friendli-Team": this.friendliTeam ?? "",
        };
    }
    constructBody(messages, stream, _options) {
        const messageList = messages.map((message) => {
            if (typeof message.content !== "string") {
                throw new Error("Friendli does not support non-string message content.");
            }
            return {
                role: messageToFriendliRole(message),
                content: message.content,
            };
        });
        const body = JSON.stringify({
            messages: messageList,
            stream,
            model: this.model,
            max_tokens: this.maxTokens,
            frequency_penalty: this.frequencyPenalty,
            stop: this.stop,
            temperature: this.temperature,
            top_p: this.topP,
            ...this.modelKwargs,
        });
        return body;
    }
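    // For reference, `constructBody` above serializes a payload shaped like
    // the following (illustrative values; fields left undefined are dropped
    // by JSON.stringify):
    //
    //   {
    //     "messages": [{ "role": "user", "content": "Hello" }],
    //     "stream": false,
    //     "model": "meta-llama-3-8b-instruct",
    //     "max_tokens": 256,
    //     "temperature": 0.7,
    //     "top_p": 1
    //   }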
    /**
     * Calls the Friendli endpoint and retrieves the result.
     * @param {BaseMessage[]} messages The input messages.
     * @returns {Promise<ChatResult>} A promise that resolves to the generated chat result.
     */
    /** @ignore */
    async _generate(messages, _options) {
        const response = (await this.caller.call(async () => fetch(`${this.baseUrl}/v1/chat/completions`, {
            method: "POST",
            headers: this.constructHeaders(false),
            body: this.constructBody(messages, false, _options),
        }).then((res) => res.json())));
        const generations = [];
        for (const data of response.choices ?? []) {
            const text = data.message?.content ?? "";
            const generation = {
                text,
                message: friendliResponseToChatMessage(data.message ?? {}),
            };
            if (data.finish_reason) {
                generation.generationInfo = { finish_reason: data.finish_reason };
            }
            generations.push(generation);
        }
        return { generations };
    }
    async *_streamResponseChunks(messages, _options, runManager) {
        const response = await this.caller.call(async () => fetch(`${this.baseUrl}/v1/chat/completions`, {
            method: "POST",
            headers: this.constructHeaders(true),
            body: this.constructBody(messages, true, _options),
        }));
        // Treat non-200 responses (or a missing body) as errors before streaming.
        if (response.status !== 200 || !response.body) {
            const errorResponse = await response.json();
            throw new Error(JSON.stringify(errorResponse));
        }
        const stream = convertEventStreamToIterableReadableDataStream(response.body);
        for await (const chunk of stream) {
            if (chunk === "[DONE]")
                break;
            const parsedChunk = JSON.parse(chunk);
            if (parsedChunk.choices[0].finish_reason === null) {
                const generationChunk = new ChatGenerationChunk({
                    message: _convertDeltaToMessageChunk(parsedChunk.choices[0].delta),
                    text: parsedChunk.choices[0].delta.content ?? "",
                    generationInfo: {
                        finishReason: parsedChunk.choices[0].finish_reason,
                    },
                });
                yield generationChunk;
                void runManager?.handleLLMNewToken(generationChunk.text ?? "");
            }
        }
    }
}
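/*
 * Streaming usage sketch (illustrative, not part of the original module;
 * assumes FRIENDLI_TOKEN is set in the environment and that the default
 * model is available on your endpoint). `.stream()` is inherited from the
 * LangChain Runnable interface and drives `_streamResponseChunks` above:
 *
 *   const model = new ChatFriendli({ maxTokens: 800 });
 *   const stream = await model.stream([new HumanMessage("Tell me a joke.")]);
 *   for await (const chunk of stream) {
 *     process.stdout.write(chunk.content);
 *   }
 */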