agsamantha/node_modules/@langchain/community/dist/llms/togetherai.js
2024-10-02 15:15:21 -05:00

188 lines
6.8 KiB
JavaScript

import { LLM, } from "@langchain/core/language_models/llms";
import { GenerationChunk } from "@langchain/core/outputs";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { convertEventStreamToIterableReadableDataStream } from "../utils/event_source_parse.js";
/**
 * Builds the Error thrown when the Together AI endpoint does not return a
 * usable completion.
 *
 * The body is read as text first so that a non-JSON payload (for example an
 * HTML error page from a proxy) still produces a descriptive error instead of
 * a JSON parsing failure that hides the real HTTP problem.
 *
 * @param {{ status: number, text: () => Promise<string> }} response - The
 *   failed fetch response.
 * @returns {Promise<Error>} Error whose message starts with the standard
 *   "Error getting prompt completion from Together AI." prefix.
 */
async function buildTogetherAIError(response) {
    // Reading the body is best effort: the status code alone is still useful.
    const rawBody = await response.text().catch(() => "");
    let details;
    try {
        details = JSON.stringify(JSON.parse(rawBody), null, 2);
    }
    catch {
        details = `HTTP ${response.status}: ${rawBody}`;
    }
    return new Error(`Error getting prompt completion from Together AI. ${details}`);
}
/**
 * LLM wrapper that sends prompts to the Together AI inference endpoint.
 *
 * Construction requires an API key (the `apiKey` input or the
 * `TOGETHER_AI_API_KEY` environment variable) and a model name (the `model`
 * or `modelName` input). Sampling parameters given at construction act as
 * defaults; values supplied in per-call options take precedence over them.
 */
export class TogetherAI extends LLM {
    static lc_name() {
        return "TogetherAI";
    }
    /**
     * @param {object} inputs - Configuration. Recognised keys: `apiKey`,
     *   `model` / `modelName`, `temperature`, `topK`, `topP`, `streaming`,
     *   `repetitionPenalty`, `logprobs`, `safetyModel`, `maxTokens`, `stop`.
     *   The whole object is also forwarded to the base `LLM` constructor.
     * @throws {Error} If no API key can be resolved or no model name is given.
     */
    constructor(inputs) {
        super(inputs);
        // Instance fields and their defaults, defined in a fixed order as own,
        // enumerable, writable, configurable properties.
        const fieldDefaults = {
            lc_serializable: true,
            temperature: 0.7,
            topP: 0.7,
            topK: 50,
            modelName: undefined,
            model: undefined,
            streaming: false,
            repetitionPenalty: 1,
            logprobs: undefined,
            maxTokens: undefined,
            safetyModel: undefined,
            stop: undefined,
            apiKey: undefined,
            inferenceAPIUrl: "https://api.together.xyz/inference",
        };
        for (const [name, value] of Object.entries(fieldDefaults)) {
            Object.defineProperty(this, name, {
                enumerable: true,
                configurable: true,
                writable: true,
                value,
            });
        }
        const apiKey = inputs.apiKey ?? getEnvironmentVariable("TOGETHER_AI_API_KEY");
        if (!apiKey) {
            throw new Error("TOGETHER_AI_API_KEY not found.");
        }
        // Select the model with the same truthiness rule used to validate it,
        // so an empty `model` falls back to `modelName` instead of winning.
        const modelName = inputs.model || inputs.modelName;
        if (!modelName) {
            throw new Error("Model name is required for TogetherAI.");
        }
        this.apiKey = apiKey;
        this.temperature = inputs.temperature ?? this.temperature;
        this.topK = inputs.topK ?? this.topK;
        this.topP = inputs.topP ?? this.topP;
        this.modelName = modelName;
        this.model = this.modelName;
        this.streaming = inputs.streaming ?? this.streaming;
        this.repetitionPenalty = inputs.repetitionPenalty ?? this.repetitionPenalty;
        this.logprobs = inputs.logprobs;
        this.safetyModel = inputs.safetyModel;
        this.maxTokens = inputs.maxTokens;
        this.stop = inputs.stop;
    }
    _llmType() {
        return "together_ai";
    }
    /**
     * @returns {object} HTTP headers for a JSON request authenticated with
     *   this instance's API key as a bearer token.
     */
    constructHeaders() {
        return {
            accept: "application/json",
            "content-type": "application/json",
            Authorization: `Bearer ${this.apiKey}`,
        };
    }
    /**
     * Builds the JSON payload for one inference request.
     *
     * Per-call options override the instance configuration for every field
     * except `stream_tokens`, which always follows the instance's `streaming`
     * flag.
     *
     * @param {string} prompt - Prompt text to complete.
     * @param {object} [options] - Per-call overrides (`model`, `modelName`,
     *   `temperature`, `topK`, `topP`, `repetitionPenalty`, `logprobs`,
     *   `safetyModel`, `maxTokens`, `stop`).
     * @returns {object} Request body using the endpoint's snake_case keys.
     */
    constructBody(prompt, options) {
        return {
            model: options?.model ?? options?.modelName ?? this.model,
            prompt,
            temperature: options?.temperature ?? this.temperature,
            top_k: options?.topK ?? this.topK,
            top_p: options?.topP ?? this.topP,
            repetition_penalty: options?.repetitionPenalty ?? this.repetitionPenalty,
            logprobs: options?.logprobs ?? this.logprobs,
            stream_tokens: this.streaming,
            safety_model: options?.safetyModel ?? this.safetyModel,
            max_tokens: options?.maxTokens ?? this.maxTokens,
            stop: options?.stop ?? this.stop,
        };
    }
    /**
     * Posts the prompt through `this.caller` and resolves with the parsed
     * JSON response.
     *
     * @param {string} prompt - Prompt text to complete.
     * @param {object} [options] - Per-call overrides, see `constructBody`.
     * @returns {Promise<object>} Parsed JSON body of a 200 response.
     * @throws {Error} If the endpoint answers with any status other than 200.
     */
    async completionWithRetry(prompt, options) {
        return this.caller.call(async () => {
            const fetchResponse = await fetch(this.inferenceAPIUrl, {
                method: "POST",
                headers: this.constructHeaders(),
                body: JSON.stringify(this.constructBody(prompt, options)),
            });
            if (fetchResponse.status === 200) {
                return fetchResponse.json();
            }
            throw await buildTogetherAIError(fetchResponse);
        });
    }
    /** @ignore */
    async _call(prompt, options) {
        const response = await this.completionWithRetry(prompt, options);
        const outputText = response.output.choices[0].text;
        return outputText ?? "";
    }
    /**
     * Streams a completion, yielding one `GenerationChunk` per event received
     * from the endpoint and reporting each token to `runManager`.
     *
     * @param {string} prompt - Prompt text to complete.
     * @param {object} [options] - Per-call overrides, see `constructBody`.
     * @param {object} [runManager] - Optional callback manager; its
     *   `handleLLMNewToken` is invoked (not awaited) for every chunk.
     * @throws {Error} If the status is not 200 or the response has no body.
     */
    async *_streamResponseChunks(prompt, options, runManager) {
        const fetchResponse = await fetch(this.inferenceAPIUrl, {
            method: "POST",
            headers: this.constructHeaders(),
            body: JSON.stringify(this.constructBody(prompt, options)),
        });
        if (fetchResponse.status !== 200) {
            throw await buildTogetherAIError(fetchResponse);
        }
        // A successful status without a body cannot be streamed; fail with a
        // clear message rather than passing a missing body to the parser.
        if (!fetchResponse.body) {
            throw new Error("Error getting prompt completion from Together AI. The response did not include a body to stream.");
        }
        const stream = convertEventStreamToIterableReadableDataStream(fetchResponse.body);
        for await (const chunk of stream) {
            // "[DONE]" is a marker event, not JSON, so it is skipped.
            if (chunk === "[DONE]") {
                continue;
            }
            const parsedChunk = JSON.parse(chunk);
            const generationChunk = new GenerationChunk({
                text: parsedChunk.choices[0].text ?? "",
            });
            yield generationChunk;
            // eslint-disable-next-line no-void
            void runManager?.handleLLMNewToken(generationChunk.text ?? "");
        }
    }
}