// 171 lines · 5.9 KiB · JavaScript
import { LLM } from "@langchain/core/language_models/llms";
|
|
import { GenerationChunk } from "@langchain/core/outputs";
|
|
import { getEnvironmentVariable } from "@langchain/core/utils/env";
|
|
/**
 * Class implementing the Large Language Model (LLM) interface using the
 * Hugging Face Inference API for text generation.
 *
 * The API key comes from the `apiKey` constructor field or, failing that, from
 * the `HUGGINGFACEHUB_API_KEY` environment variable. When `endpointUrl` is set,
 * requests are sent to that endpoint instead of the model named by `model`.
 *
 * @example
 * ```typescript
 * const model = new HuggingFaceInference({
 *   model: "gpt2",
 *   temperature: 0.7,
 *   maxTokens: 50,
 * });
 *
 * const res = await model.invoke(
 *   "Question: What would be a good company name for a company that makes colorful socks?\nAnswer:"
 * );
 * console.log({ res });
 * ```
 */
export class HuggingFaceInference extends LLM {
  /**
   * Declares `apiKey` as a secret backed by the `HUGGINGFACEHUB_API_KEY`
   * environment variable (consumed by the LangChain base classes —
   * presumably so serialized output references the variable, not the key).
   */
  get lc_secrets() {
    return {
      apiKey: "HUGGINGFACEHUB_API_KEY",
    };
  }

  /**
   * @param {object} [fields] - Optional configuration: `model`, `temperature`,
   *   `maxTokens`, `stopSequences`, `topP`, `topK`, `frequencyPenalty`,
   *   `apiKey`, `endpointUrl`, `includeCredentials`; the whole object is also
   *   forwarded to the `LLM` base constructor.
   * @throws {Error} If no `apiKey` is given and `HUGGINGFACEHUB_API_KEY` is unset.
   */
  constructor(fields) {
    super(fields ?? {});
    // Own-property definitions equivalent to (compiled) class-field declarations.
    Object.defineProperty(this, "lc_serializable", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: true,
    });
    Object.defineProperty(this, "model", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: "gpt2",
    });
    Object.defineProperty(this, "temperature", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: undefined,
    });
    Object.defineProperty(this, "maxTokens", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: undefined,
    });
    Object.defineProperty(this, "stopSequences", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: undefined,
    });
    Object.defineProperty(this, "topP", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: undefined,
    });
    Object.defineProperty(this, "topK", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: undefined,
    });
    Object.defineProperty(this, "frequencyPenalty", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: undefined,
    });
    Object.defineProperty(this, "apiKey", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: undefined,
    });
    Object.defineProperty(this, "endpointUrl", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: undefined,
    });
    Object.defineProperty(this, "includeCredentials", {
      enumerable: true,
      configurable: true,
      writable: true,
      value: undefined,
    });
    this.model = fields?.model ?? this.model;
    this.temperature = fields?.temperature ?? this.temperature;
    this.maxTokens = fields?.maxTokens ?? this.maxTokens;
    this.stopSequences = fields?.stopSequences ?? this.stopSequences;
    this.topP = fields?.topP ?? this.topP;
    this.topK = fields?.topK ?? this.topK;
    this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty;
    this.apiKey =
      fields?.apiKey ?? getEnvironmentVariable("HUGGINGFACEHUB_API_KEY");
    this.endpointUrl = fields?.endpointUrl;
    this.includeCredentials = fields?.includeCredentials;
    if (!this.apiKey) {
      throw new Error(`Please set an API key for HuggingFace Hub in the environment variable "HUGGINGFACEHUB_API_KEY" or in the apiKey field of the HuggingFaceInference constructor.`);
    }
  }

  /** @returns {string} The identifier of this LLM type, `"hf"`. */
  _llmType() {
    return "hf";
  }

  /**
   * Builds the request arguments shared by the streaming and non-streaming
   * paths. A per-call `options.stop` takes precedence over `stopSequences`.
   * @param {object} [options] - Parsed call options (may be omitted).
   * @returns {{ model: string, parameters: object }} Arguments for the HF client.
   */
  invocationParams(options) {
    return {
      model: this.model,
      parameters: {
        // make it behave similar to openai, returning only the generated text
        return_full_text: false,
        temperature: this.temperature,
        max_new_tokens: this.maxTokens,
        stop: options?.stop ?? this.stopSequences,
        top_p: this.topP,
        top_k: this.topK,
        // NOTE(review): frequencyPenalty is sent as HF's repetition_penalty.
        repetition_penalty: this.frequencyPenalty,
      },
    };
  }

  /**
   * Streams generated tokens as `GenerationChunk`s, reporting each token to
   * `runManager`, and yields a final empty chunk with `{ finished: true }`
   * once the stream reports its complete `generated_text`.
   * @param {string} prompt - Input text.
   * @param {object} [options] - Parsed call options (`stop`, `signal`, ...).
   * @param {object} [runManager] - Optional callback manager.
   */
  async *_streamResponseChunks(prompt, options, runManager) {
    const hfi = await this._prepareHFInference();
    // Honour the abort signal the same way `_call` does.
    const stream = await this.caller.callWithOptions(
      { signal: options?.signal },
      async () =>
        hfi.textGenerationStream({
          ...this.invocationParams(options),
          inputs: prompt,
        })
    );
    for await (const chunk of stream) {
      // Normalize a missing token text so the yielded chunk and the callback
      // always see the same string.
      const token = chunk.token.text ?? "";
      yield new GenerationChunk({ text: token, generationInfo: chunk });
      await runManager?.handleLLMNewToken(token);
      // A present `generated_text` marks the end of the stream. It may be an
      // empty string, so test for null/undefined rather than truthiness.
      if (chunk.generated_text != null) {
        yield new GenerationChunk({
          text: "",
          generationInfo: { finished: true },
        });
      }
    }
  }

  /**
   * Runs a single (non-streaming) text generation.
   * @ignore
   * @param {string} prompt - Input text.
   * @param {object} [options] - Parsed call options; may be omitted.
   * @returns {Promise<string>} The generated text.
   */
  async _call(prompt, options) {
    const hfi = await this._prepareHFInference();
    const args = { ...this.invocationParams(options), inputs: prompt };
    const res = await this.caller.callWithOptions(
      { signal: options?.signal },
      hfi.textGeneration.bind(hfi),
      args
    );
    return res.generated_text;
  }

  /**
   * Creates an `HfInference` client for this instance, targeting
   * `endpointUrl` when one is configured.
   * @ignore
   */
  async _prepareHFInference() {
    const { HfInference } = await HuggingFaceInference.imports();
    const hfi = new HfInference(this.apiKey, {
      includeCredentials: this.includeCredentials,
    });
    return this.endpointUrl ? hfi.endpoint(this.endpointUrl) : hfi;
  }

  /**
   * Lazily loads the optional `@huggingface/inference` dependency.
   * @ignore
   * @throws {Error} With installation instructions if the import fails; the
   *   underlying import error is preserved as `cause`.
   */
  static async imports() {
    try {
      const { HfInference } = await import("@huggingface/inference");
      return { HfInference };
    } catch (e) {
      throw new Error(
        "Please install huggingface as a dependency with, e.g. `yarn add @huggingface/inference`",
        { cause: e }
      );
    }
  }
}
|