447 lines
15 KiB
JavaScript
447 lines
15 KiB
JavaScript
|
import { BaseChatModel, } from "@langchain/core/language_models/chat_models";
|
||
|
import { AIMessage, ChatMessage, AIMessageChunk, } from "@langchain/core/messages";
|
||
|
import { ChatGenerationChunk } from "@langchain/core/outputs";
|
||
|
import { getEnvironmentVariable } from "@langchain/core/utils/env";
|
||
|
import { IterableReadableStream } from "@langchain/core/utils/stream";
|
||
|
/**
|
||
|
* Function that extracts the custom role of a generic chat message.
|
||
|
* @param message Chat message from which to extract the custom role.
|
||
|
* @returns The custom role of the chat message.
|
||
|
*/
|
||
|
function extractGenericMessageCustomRole(message) {
|
||
|
if (["system", "assistant", "user"].includes(message.role) === false) {
|
||
|
console.warn(`Unknown message role: ${message.role}`);
|
||
|
}
|
||
|
return message.role;
|
||
|
}
|
||
|
/**
|
||
|
* Function that converts a base message to a Tongyi message role.
|
||
|
* @param message Base message to convert.
|
||
|
* @returns The Tongyi message role.
|
||
|
*/
|
||
|
function messageToTongyiRole(message) {
|
||
|
const type = message._getType();
|
||
|
switch (type) {
|
||
|
case "ai":
|
||
|
return "assistant";
|
||
|
case "human":
|
||
|
return "user";
|
||
|
case "system":
|
||
|
return "system";
|
||
|
case "function":
|
||
|
throw new Error("Function messages not supported");
|
||
|
case "generic": {
|
||
|
if (!ChatMessage.isInstance(message))
|
||
|
throw new Error("Invalid generic chat message");
|
||
|
return extractGenericMessageCustomRole(message);
|
||
|
}
|
||
|
default:
|
||
|
throw new Error(`Unknown message type: ${type}`);
|
||
|
}
|
||
|
}
|
||
|
/**
|
||
|
* Wrapper around Ali Tongyi large language models that use the Chat endpoint.
|
||
|
*
|
||
|
* To use you should have the `ALIBABA_API_KEY`
|
||
|
* environment variable set.
|
||
|
*
|
||
|
* @augments BaseLLM
|
||
|
* @augments AlibabaTongyiInput
|
||
|
* @example
|
||
|
* ```typescript
|
||
|
* const qwen = new ChatAlibabaTongyi({
|
||
|
* alibabaApiKey: "YOUR-API-KEY",
|
||
|
* });
|
||
|
*
|
||
|
* const qwen = new ChatAlibabaTongyi({
|
||
|
* model: "qwen-turbo",
|
||
|
* temperature: 1,
|
||
|
* alibabaApiKey: "YOUR-API-KEY",
|
||
|
* });
|
||
|
*
|
||
|
* const messages = [new HumanMessage("Hello")];
|
||
|
*
|
||
|
* await qwen.call(messages);
|
||
|
* ```
|
||
|
*/
|
||
|
export class ChatAlibabaTongyi extends BaseChatModel {
|
||
|
static lc_name() {
|
||
|
return "ChatAlibabaTongyi";
|
||
|
}
|
||
|
get callKeys() {
|
||
|
return ["stop", "signal", "options"];
|
||
|
}
|
||
|
get lc_secrets() {
|
||
|
return {
|
||
|
alibabaApiKey: "ALIBABA_API_KEY",
|
||
|
};
|
||
|
}
|
||
|
get lc_aliases() {
|
||
|
return undefined;
|
||
|
}
|
||
|
constructor(fields = {}) {
|
||
|
super(fields);
|
||
|
Object.defineProperty(this, "lc_serializable", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
Object.defineProperty(this, "alibabaApiKey", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
Object.defineProperty(this, "streaming", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
Object.defineProperty(this, "prefixMessages", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
Object.defineProperty(this, "modelName", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
Object.defineProperty(this, "model", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
Object.defineProperty(this, "apiUrl", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
Object.defineProperty(this, "maxTokens", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
Object.defineProperty(this, "temperature", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
Object.defineProperty(this, "topP", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
Object.defineProperty(this, "topK", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
Object.defineProperty(this, "repetitionPenalty", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
Object.defineProperty(this, "seed", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
Object.defineProperty(this, "enableSearch", {
|
||
|
enumerable: true,
|
||
|
configurable: true,
|
||
|
writable: true,
|
||
|
value: void 0
|
||
|
});
|
||
|
this.alibabaApiKey =
|
||
|
fields?.alibabaApiKey ?? getEnvironmentVariable("ALIBABA_API_KEY");
|
||
|
if (!this.alibabaApiKey) {
|
||
|
throw new Error("Ali API key not found");
|
||
|
}
|
||
|
this.apiUrl =
|
||
|
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation";
|
||
|
this.lc_serializable = true;
|
||
|
this.streaming = fields.streaming ?? false;
|
||
|
this.prefixMessages = fields.prefixMessages ?? [];
|
||
|
this.temperature = fields.temperature;
|
||
|
this.topP = fields.topP;
|
||
|
this.topK = fields.topK;
|
||
|
this.seed = fields.seed;
|
||
|
this.maxTokens = fields.maxTokens;
|
||
|
this.repetitionPenalty = fields.repetitionPenalty;
|
||
|
this.enableSearch = fields.enableSearch;
|
||
|
this.modelName = fields?.model ?? fields.modelName ?? "qwen-turbo";
|
||
|
this.model = this.modelName;
|
||
|
}
|
||
|
/**
|
||
|
* Get the parameters used to invoke the model
|
||
|
*/
|
||
|
invocationParams() {
|
||
|
const parameters = {
|
||
|
stream: this.streaming,
|
||
|
temperature: this.temperature,
|
||
|
top_p: this.topP,
|
||
|
top_k: this.topK,
|
||
|
seed: this.seed,
|
||
|
max_tokens: this.maxTokens,
|
||
|
result_format: "text",
|
||
|
enable_search: this.enableSearch,
|
||
|
};
|
||
|
if (this.streaming) {
|
||
|
parameters.incremental_output = true;
|
||
|
}
|
||
|
else {
|
||
|
parameters.repetition_penalty = this.repetitionPenalty;
|
||
|
}
|
||
|
return parameters;
|
||
|
}
|
||
|
/**
|
||
|
* Get the identifying parameters for the model
|
||
|
*/
|
||
|
identifyingParams() {
|
||
|
return {
|
||
|
model: this.model,
|
||
|
...this.invocationParams(),
|
||
|
};
|
||
|
}
|
||
|
/** @ignore */
|
||
|
async _generate(messages, options, runManager) {
|
||
|
const parameters = this.invocationParams();
|
||
|
const messagesMapped = messages.map((message) => ({
|
||
|
role: messageToTongyiRole(message),
|
||
|
content: message.content,
|
||
|
}));
|
||
|
const data = parameters.stream
|
||
|
? await new Promise((resolve, reject) => {
|
||
|
let response;
|
||
|
let rejected = false;
|
||
|
let resolved = false;
|
||
|
this.completionWithRetry({
|
||
|
model: this.model,
|
||
|
parameters,
|
||
|
input: {
|
||
|
messages: messagesMapped,
|
||
|
},
|
||
|
}, true, options?.signal, (event) => {
|
||
|
const data = JSON.parse(event.data);
|
||
|
if (data?.code) {
|
||
|
if (rejected) {
|
||
|
return;
|
||
|
}
|
||
|
rejected = true;
|
||
|
reject(new Error(data?.message));
|
||
|
return;
|
||
|
}
|
||
|
const { text, finish_reason } = data.output;
|
||
|
if (!response) {
|
||
|
response = data;
|
||
|
}
|
||
|
else {
|
||
|
response.output.text += text;
|
||
|
response.output.finish_reason = finish_reason;
|
||
|
response.usage = data.usage;
|
||
|
}
|
||
|
void runManager?.handleLLMNewToken(text ?? "");
|
||
|
if (finish_reason && finish_reason !== "null") {
|
||
|
if (resolved || rejected) {
|
||
|
return;
|
||
|
}
|
||
|
resolved = true;
|
||
|
resolve(response);
|
||
|
}
|
||
|
}).catch((error) => {
|
||
|
if (!rejected) {
|
||
|
rejected = true;
|
||
|
reject(error);
|
||
|
}
|
||
|
});
|
||
|
})
|
||
|
: await this.completionWithRetry({
|
||
|
model: this.model,
|
||
|
parameters,
|
||
|
input: {
|
||
|
messages: messagesMapped,
|
||
|
},
|
||
|
}, false, options?.signal).then((data) => {
|
||
|
if (data?.code) {
|
||
|
throw new Error(data?.message);
|
||
|
}
|
||
|
return data;
|
||
|
});
|
||
|
const { input_tokens = 0, output_tokens = 0, total_tokens = 0, } = data.usage;
|
||
|
const { text } = data.output;
|
||
|
return {
|
||
|
generations: [
|
||
|
{
|
||
|
text,
|
||
|
message: new AIMessage(text),
|
||
|
},
|
||
|
],
|
||
|
llmOutput: {
|
||
|
tokenUsage: {
|
||
|
promptTokens: input_tokens,
|
||
|
completionTokens: output_tokens,
|
||
|
totalTokens: total_tokens,
|
||
|
},
|
||
|
},
|
||
|
};
|
||
|
}
|
||
|
/** @ignore */
|
||
|
async completionWithRetry(request, stream, signal, onmessage) {
|
||
|
const makeCompletionRequest = async () => {
|
||
|
const response = await fetch(this.apiUrl, {
|
||
|
method: "POST",
|
||
|
headers: {
|
||
|
...(stream ? { Accept: "text/event-stream" } : {}),
|
||
|
Authorization: `Bearer ${this.alibabaApiKey}`,
|
||
|
"Content-Type": "application/json",
|
||
|
},
|
||
|
body: JSON.stringify(request),
|
||
|
signal,
|
||
|
});
|
||
|
if (!stream) {
|
||
|
return response.json();
|
||
|
}
|
||
|
if (response.body) {
|
||
|
// response will not be a stream if an error occurred
|
||
|
if (!response.headers.get("content-type")?.startsWith("text/event-stream")) {
|
||
|
onmessage?.(new MessageEvent("message", {
|
||
|
data: await response.text(),
|
||
|
}));
|
||
|
return;
|
||
|
}
|
||
|
const reader = response.body.getReader();
|
||
|
const decoder = new TextDecoder("utf-8");
|
||
|
let data = "";
|
||
|
let continueReading = true;
|
||
|
while (continueReading) {
|
||
|
const { done, value } = await reader.read();
|
||
|
if (done) {
|
||
|
continueReading = false;
|
||
|
break;
|
||
|
}
|
||
|
data += decoder.decode(value);
|
||
|
let continueProcessing = true;
|
||
|
while (continueProcessing) {
|
||
|
const newlineIndex = data.indexOf("\n");
|
||
|
if (newlineIndex === -1) {
|
||
|
continueProcessing = false;
|
||
|
break;
|
||
|
}
|
||
|
const line = data.slice(0, newlineIndex);
|
||
|
data = data.slice(newlineIndex + 1);
|
||
|
if (line.startsWith("data:")) {
|
||
|
const event = new MessageEvent("message", {
|
||
|
data: line.slice("data:".length).trim(),
|
||
|
});
|
||
|
onmessage?.(event);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
};
|
||
|
return this.caller.call(makeCompletionRequest);
|
||
|
}
|
||
|
async *_streamResponseChunks(messages, options, runManager) {
|
||
|
const parameters = {
|
||
|
...this.invocationParams(),
|
||
|
stream: true,
|
||
|
incremental_output: true,
|
||
|
};
|
||
|
const messagesMapped = messages.map((message) => ({
|
||
|
role: messageToTongyiRole(message),
|
||
|
content: message.content,
|
||
|
}));
|
||
|
const stream = await this.caller.call(async () => this.createTongyiStream({
|
||
|
model: this.model,
|
||
|
parameters,
|
||
|
input: {
|
||
|
messages: messagesMapped,
|
||
|
},
|
||
|
}, options?.signal));
|
||
|
for await (const chunk of stream) {
|
||
|
const { text, finish_reason } = chunk.output;
|
||
|
yield new ChatGenerationChunk({
|
||
|
text,
|
||
|
message: new AIMessageChunk({ content: text }),
|
||
|
generationInfo: finish_reason === "stop"
|
||
|
? {
|
||
|
finish_reason,
|
||
|
request_id: chunk.request_id,
|
||
|
usage: chunk.usage,
|
||
|
}
|
||
|
: undefined,
|
||
|
});
|
||
|
await runManager?.handleLLMNewToken(text);
|
||
|
}
|
||
|
}
|
||
|
async *createTongyiStream(request, signal) {
|
||
|
const response = await fetch(this.apiUrl, {
|
||
|
method: "POST",
|
||
|
headers: {
|
||
|
Authorization: `Bearer ${this.alibabaApiKey}`,
|
||
|
Accept: "text/event-stream",
|
||
|
"Content-Type": "application/json",
|
||
|
},
|
||
|
body: JSON.stringify(request),
|
||
|
signal,
|
||
|
});
|
||
|
if (!response.ok) {
|
||
|
let error;
|
||
|
const responseText = await response.text();
|
||
|
try {
|
||
|
const json = JSON.parse(responseText);
|
||
|
error = new Error(`Tongyi call failed with status code ${response.status}: ${json.error}`);
|
||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||
|
}
|
||
|
catch (e) {
|
||
|
error = new Error(`Tongyi call failed with status code ${response.status}: ${responseText}`);
|
||
|
}
|
||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||
|
error.response = response;
|
||
|
throw error;
|
||
|
}
|
||
|
if (!response.body) {
|
||
|
throw new Error("Could not begin Tongyi stream. Please check the given URL and try again.");
|
||
|
}
|
||
|
const stream = IterableReadableStream.fromReadableStream(response.body);
|
||
|
const decoder = new TextDecoder();
|
||
|
let extra = "";
|
||
|
for await (const chunk of stream) {
|
||
|
const decoded = extra + decoder.decode(chunk);
|
||
|
const lines = decoded.split("\n");
|
||
|
extra = lines.pop() || "";
|
||
|
for (const line of lines) {
|
||
|
if (!line.startsWith("data:")) {
|
||
|
continue;
|
||
|
}
|
||
|
try {
|
||
|
yield JSON.parse(line.slice("data:".length).trim());
|
||
|
}
|
||
|
catch (e) {
|
||
|
console.warn(`Received a non-JSON parseable chunk: ${line}`);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
_llmType() {
|
||
|
return "alibaba_tongyi";
|
||
|
}
|
||
|
/** @ignore */
|
||
|
_combineLLMOutput() {
|
||
|
return [];
|
||
|
}
|
||
|
}
|