import { BaseChatModel } from "@langchain/core/language_models/chat_models";
import {
    AIMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
    HumanMessageChunk,
    AIMessageChunk,
    SystemMessageChunk,
    ChatMessageChunk,
} from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { convertEventStreamToIterableReadableDataStream } from "../utils/event_source_parse.js";
/**
 * Maps a LangChain message to the role name expected by the Friendli chat API.
 */
function messageToFriendliRole(message) {
    const type = message._getType();
    switch (type) {
        case "ai":
            return "assistant";
        case "human":
            return "user";
        case "system":
            return "system";
        case "function":
            throw new Error("Function messages not supported");
        case "generic": {
            if (!ChatMessage.isInstance(message)) {
                throw new Error("Invalid generic chat message");
            }
            if (["system", "assistant", "user"].includes(message.role)) {
                return message.role;
            }
            throw new Error(`Unknown message type: ${type}`);
        }
        default:
            throw new Error(`Unknown message type: ${type}`);
    }
}
/**
 * Converts a Friendli response message into the corresponding LangChain message class.
 */
function friendliResponseToChatMessage(message) {
    switch (message.role) {
        case "user":
            return new HumanMessage(message.content ?? "");
        case "assistant":
            return new AIMessage(message.content ?? "");
        case "system":
            return new SystemMessage(message.content ?? "");
        default:
            return new ChatMessage(message.content ?? "", message.role ?? "unknown");
    }
}
/**
 * Converts a streaming delta from the Friendli API into a LangChain message chunk.
 */
function _convertDeltaToMessageChunk(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
delta) {
    const role = delta.role ?? "assistant";
    const content = delta.content ?? "";
    let additional_kwargs;
    if (delta.function_call) {
        additional_kwargs = {
            function_call: delta.function_call,
        };
    }
    else {
        additional_kwargs = {};
    }
    if (role === "user") {
        return new HumanMessageChunk({ content });
    }
    else if (role === "assistant") {
        return new AIMessageChunk({ content, additional_kwargs });
    }
    else if (role === "system") {
        return new SystemMessageChunk({ content });
    }
    else {
        return new ChatMessageChunk({ content, role });
    }
}
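/*
 * A minimal usage sketch for the ChatFriendli class defined below (not part of the original
 * file). It assumes this module is exposed as "@langchain/community/chat_models/friendli"
 * and that FRIENDLI_TOKEN is set in the environment; the model name and parameters shown
 * are illustrative.
 *
 *   import { ChatFriendli } from "@langchain/community/chat_models/friendli";
 *   import { HumanMessage } from "@langchain/core/messages";
 *
 *   const chat = new ChatFriendli({
 *     model: "meta-llama-3-8b-instruct",
 *     maxTokens: 256,
 *     temperature: 0.7,
 *   });
 *   const result = await chat.invoke([new HumanMessage("What is Friendli?")]);
 *   console.log(result.content);
 */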
/**
 * The ChatFriendli class is used to interact with Friendli inference Endpoint models.
 * It requires a Friendli Token and, optionally, a Friendli Team; both are loaded from the
 * FRIENDLI_TOKEN and FRIENDLI_TEAM environment variables if not specified.
 */
export class ChatFriendli extends BaseChatModel {
    static lc_name() {
        return "Friendli";
    }
    get lc_secrets() {
        return {
            friendliToken: "FRIENDLI_TOKEN",
            friendliTeam: "FRIENDLI_TEAM",
        };
    }
    constructor(fields) {
        super(fields);
        this.lc_serializable = true;
        this.model = "meta-llama-3-8b-instruct";
        this.baseUrl = "https://inference.friendli.ai";
        this.friendliToken = undefined;
        this.friendliTeam = undefined;
        this.frequencyPenalty = undefined;
        this.maxTokens = undefined;
        this.stop = undefined;
        this.temperature = undefined;
        this.topP = undefined;
        this.modelKwargs = undefined;
        this.model = fields?.model ?? this.model;
        this.baseUrl = fields?.baseUrl ?? this.baseUrl;
        this.friendliToken =
            fields?.friendliToken ?? getEnvironmentVariable("FRIENDLI_TOKEN");
        this.friendliTeam =
            fields?.friendliTeam ?? getEnvironmentVariable("FRIENDLI_TEAM");
        this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty;
        this.maxTokens = fields?.maxTokens ?? this.maxTokens;
        this.stop = fields?.stop ?? this.stop;
        this.temperature = fields?.temperature ?? this.temperature;
        this.topP = fields?.topP ?? this.topP;
        this.modelKwargs = fields?.modelKwargs ?? {};
        if (!this.friendliToken) {
            throw new Error("Missing Friendli Token");
        }
    }
    _llmType() {
        return "friendli";
    }
    constructHeaders(stream) {
        return {
            "Content-Type": "application/json",
            Accept: stream ? "text/event-stream" : "application/json",
            Authorization: `Bearer ${this.friendliToken}`,
            "X-Friendli-Team": this.friendliTeam ?? "",
        };
    }
    constructBody(messages, stream, _options) {
        const messageList = messages.map((message) => {
            if (typeof message.content !== "string") {
                throw new Error("Friendli does not support non-string message content.");
            }
            return {
                role: messageToFriendliRole(message),
                content: message.content,
            };
        });
        const body = JSON.stringify({
            messages: messageList,
            stream,
            model: this.model,
            max_tokens: this.maxTokens,
            frequency_penalty: this.frequencyPenalty,
            stop: this.stop,
            temperature: this.temperature,
            top_p: this.topP,
            ...this.modelKwargs,
        });
        return body;
    }
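    /*
     * For reference, constructBody above serializes a payload along these lines (values are
     * illustrative; optional fields left undefined are dropped by JSON.stringify):
     *
     *   {
     *     "messages": [{ "role": "user", "content": "Hello" }],
     *     "stream": false,
     *     "model": "meta-llama-3-8b-instruct",
     *     "max_tokens": 256,
     *     "temperature": 0.7
     *   }
     */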
    /**
     * Calls the Friendli endpoint and retrieves the result.
     * @param {BaseMessage[]} messages The input messages.
     * @returns {Promise<ChatResult>} A promise that resolves to the generated chat result.
     */
    /** @ignore */
    async _generate(messages, _options) {
        const response = (await this.caller.call(async () => fetch(`${this.baseUrl}/v1/chat/completions`, {
            method: "POST",
            headers: this.constructHeaders(false),
            body: this.constructBody(messages, false, _options),
        }).then((res) => res.json())));
        const generations = [];
        for (const data of response.choices ?? []) {
            const text = data.message?.content ?? "";
            const generation = {
                text,
                message: friendliResponseToChatMessage(data.message ?? {}),
            };
            if (data.finish_reason) {
                generation.generationInfo = { finish_reason: data.finish_reason };
            }
            generations.push(generation);
        }
        return { generations };
    }
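    /*
     * _generate above relies on only a few fields of the completion response; a minimal example
     * of the shape it expects (illustrative, not the full Friendli response schema) is:
     *
     *   {
     *     "choices": [
     *       {
     *         "message": { "role": "assistant", "content": "Hi there!" },
     *         "finish_reason": "stop"
     *       }
     *     ]
     *   }
     */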
    async *_streamResponseChunks(messages, _options, runManager) {
        const response = await this.caller.call(async () => fetch(`${this.baseUrl}/v1/chat/completions`, {
            method: "POST",
            headers: this.constructHeaders(true),
            body: this.constructBody(messages, true, _options),
        }));
        if (response.status !== 200 || !response.body) {
            const errorResponse = await response.json();
            throw new Error(JSON.stringify(errorResponse));
        }
        const stream = convertEventStreamToIterableReadableDataStream(response.body);
        for await (const chunk of stream) {
            if (chunk === "[DONE]")
                break;
            const parsedChunk = JSON.parse(chunk);
            if (parsedChunk.choices[0].finish_reason === null) {
                const generationChunk = new ChatGenerationChunk({
                    message: _convertDeltaToMessageChunk(parsedChunk.choices[0].delta),
                    text: parsedChunk.choices[0].delta.content ?? "",
                    generationInfo: {
                        finishReason: parsedChunk.choices[0].finish_reason,
                    },
                });
                yield generationChunk;
                void runManager?.handleLLMNewToken(generationChunk.text ?? "");
            }
        }
    }
}
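/*
 * A minimal streaming sketch (not part of the original file). It assumes the same import path
 * and environment setup as the usage sketch above; each yielded chunk carries an incremental
 * content delta.
 *
 *   const chat = new ChatFriendli({ model: "meta-llama-3-8b-instruct" });
 *   for await (const chunk of await chat.stream("Tell me a short joke.")) {
 *     process.stdout.write(chunk.content);
 *   }
 */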