agsamantha/node_modules/@langchain/community/dist/chat_models/deepinfra.js

137 lines
4.5 KiB
JavaScript
Raw Normal View History

2024-10-02 20:15:21 +00:00
import { BaseChatModel, } from "@langchain/core/language_models/chat_models";
import { AIMessage } from "@langchain/core/messages";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
export const DEFAULT_MODEL = "meta-llama/Meta-Llama-3-70B-Instruct";
export const API_BASE_URL = "https://api.deepinfra.com/v1/openai/chat/completions";
export const ENV_VARIABLE_API_KEY = "DEEPINFRA_API_TOKEN";
/**
 * Translate a message's type tag into the role string placed in the request.
 *
 * @param {{ _getType: () => string }} message - Message whose `_getType()`
 *   result selects the role.
 * @returns {string} `"assistant"` for `"ai"`, `"user"` for `"human"`,
 *   `"system"` for `"system"`.
 * @throws {Error} When the type is none of the three above.
 */
function messageToRole(message) {
    // A Map (rather than a plain object) so that names such as "toString"
    // never resolve through the prototype chain.
    const roleByType = new Map([
        ["ai", "assistant"],
        ["human", "user"],
        ["system", "system"],
    ]);
    const messageType = message._getType();
    const role = roleByType.get(messageType);
    if (role === undefined) {
        throw new Error(`Unknown message type: ${messageType}`);
    }
    return role;
}
/**
 * Chat model that sends non-streaming chat-completion requests to the
 * endpoint in `API_BASE_URL`, authenticating with a bearer API key.
 *
 * Constructor fields read by this class:
 * - `apiKey`: falls back to the `DEEPINFRA_API_TOKEN` environment variable.
 * - `model`: falls back to `DEFAULT_MODEL`.
 * - `temperature`: falls back to `0`.
 * - `maxTokens`: left undefined when not supplied.
 */
export class ChatDeepInfra extends BaseChatModel {
    static lc_name() {
        return "ChatDeepInfra";
    }
    /** Names of the call options this model accepts. */
    get callKeys() {
        return ["stop", "signal", "options"];
    }
    /**
     * @param {object} [fields] - Configuration; also forwarded to the base class.
     * @throws {Error} When no API key is passed and the environment variable is unset.
     */
    constructor(fields = {}) {
        super(fields);
        // Declare the instance properties up front (same descriptors and
        // order as individual declarations) before assigning real values.
        for (const propertyName of ["apiKey", "model", "apiUrl", "maxTokens", "temperature"]) {
            Object.defineProperty(this, propertyName, {
                enumerable: true,
                configurable: true,
                writable: true,
                value: void 0
            });
        }
        this.apiKey =
            fields?.apiKey ?? getEnvironmentVariable(ENV_VARIABLE_API_KEY);
        if (!this.apiKey) {
            throw new Error("API key is required, set `DEEPINFRA_API_TOKEN` environment variable or pass it as a parameter");
        }
        this.apiUrl = API_BASE_URL;
        this.model = fields.model ?? DEFAULT_MODEL;
        this.temperature = fields.temperature ?? 0;
        this.maxTokens = fields.maxTokens;
    }
    /**
     * Parameters sent with every request, derived from instance state.
     * @returns {{model: string, stream: boolean, temperature: number, max_tokens: (number|undefined)}}
     */
    invocationParams() {
        return {
            model: this.model,
            stream: false,
            temperature: this.temperature,
            max_tokens: this.maxTokens,
        };
    }
    /** Same values as {@link ChatDeepInfra#invocationParams}. */
    identifyingParams() {
        return this.invocationParams();
    }
    /**
     * Send the conversation to the API and wrap the first returned choice.
     *
     * @param {Array} messages - Messages exposing `_getType()` and `content`.
     * @param {object} [options] - Call options; only `signal` is used here.
     * @returns {Promise<object>} An object with one generation (text plus an
     *   `AIMessage`) and `llmOutput.tokenUsage` (zeros when the response has
     *   no `usage` data).
     * @throws {Error} When a message type is unknown, when the response
     *   carries an error `code`, or when it contains no usable choice.
     */
    async _generate(messages, options) {
        const parameters = this.invocationParams();
        const messagesMapped = messages.map((message) => ({
            role: messageToRole(message),
            content: message.content,
        }));
        const data = await this.completionWithRetry({ ...parameters, messages: messagesMapped }, false, options?.signal);
        if (data?.code) {
            // Fall back to the code so the error is never thrown with an empty message.
            throw new Error(data.message ?? `DeepInfra API error (code ${data.code})`);
        }
        // Validate here, outside the retried request closure, so a malformed
        // payload surfaces as a clear error without altering retry behaviour.
        const firstChoice = data?.choices?.[0];
        if (!firstChoice?.message) {
            throw new Error(`DeepInfra API response contained no completion choices: ${JSON.stringify(data)}`);
        }
        const text = firstChoice.message.content;
        const { prompt_tokens = 0, completion_tokens = 0, total_tokens = 0, } = data.usage ?? {};
        return {
            generations: [{ text, message: new AIMessage(text) }],
            llmOutput: {
                tokenUsage: {
                    promptTokens: prompt_tokens,
                    completionTokens: completion_tokens,
                    totalTokens: total_tokens,
                },
            },
        };
    }
    /**
     * POST a completion request through `this.caller`.
     *
     * @param {object} request - Request fields; they override the instance's
     *   `temperature` / `max_tokens` defaults, but `model` is always the
     *   instance's model.
     * @param {boolean} stream - When falsy the parsed JSON body is returned.
     *   NOTE(review): when truthy nothing is returned (the promise resolves
     *   to `undefined`); streaming is not implemented in this class.
     * @param {AbortSignal} [signal] - Passed to `fetch` for cancellation.
     * @returns {Promise<any>} Parsed JSON response, or `undefined` if `stream` is truthy.
     */
    async completionWithRetry(request, stream, signal) {
        const body = {
            temperature: this.temperature,
            max_tokens: this.maxTokens,
            ...request,
            model: this.model,
        };
        const makeCompletionRequest = async () => {
            const response = await fetch(this.apiUrl, {
                method: "POST",
                headers: {
                    Authorization: `Bearer ${this.apiKey}`,
                    "Content-Type": "application/json",
                },
                body: JSON.stringify(body),
                signal,
            });
            if (!stream) {
                return response.json();
            }
        };
        return this.caller.call(makeCompletionRequest);
    }
    _llmType() {
        return "DeepInfra";
    }
}