Diffstat (limited to 'apps/workers/inference.ts')
| -rw-r--r-- | apps/workers/inference.ts | 155 |
1 files changed, 0 insertions, 155 deletions
diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
deleted file mode 100644
index fed9478f..00000000
--- a/apps/workers/inference.ts
+++ /dev/null
@@ -1,155 +0,0 @@
-import { Ollama } from "ollama";
-import OpenAI from "openai";
-
-import serverConfig from "@hoarder/shared/config";
-import logger from "@hoarder/shared/logger";
-
-export interface InferenceResponse {
-  response: string;
-  totalTokens: number | undefined;
-}
-
-export interface InferenceClient {
-  inferFromText(prompt: string): Promise<InferenceResponse>;
-  inferFromImage(
-    prompt: string,
-    contentType: string,
-    image: string,
-  ): Promise<InferenceResponse>;
-}
-
-export class InferenceClientFactory {
-  static build(): InferenceClient | null {
-    if (serverConfig.inference.openAIApiKey) {
-      return new OpenAIInferenceClient();
-    }
-
-    if (serverConfig.inference.ollamaBaseUrl) {
-      return new OllamaInferenceClient();
-    }
-    return null;
-  }
-}
-
-class OpenAIInferenceClient implements InferenceClient {
-  openAI: OpenAI;
-
-  constructor() {
-    this.openAI = new OpenAI({
-      apiKey: serverConfig.inference.openAIApiKey,
-      baseURL: serverConfig.inference.openAIBaseUrl,
-    });
-  }
-
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
-    const chatCompletion = await this.openAI.chat.completions.create({
-      messages: [{ role: "user", content: prompt }],
-      model: serverConfig.inference.textModel,
-      response_format: { type: "json_object" },
-    });
-
-    const response = chatCompletion.choices[0].message.content;
-    if (!response) {
-      throw new Error(`Got no message content from OpenAI`);
-    }
-    return { response, totalTokens: chatCompletion.usage?.total_tokens };
-  }
-
-  async inferFromImage(
-    prompt: string,
-    contentType: string,
-    image: string,
-  ): Promise<InferenceResponse> {
-    const chatCompletion = await this.openAI.chat.completions.create({
-      model: serverConfig.inference.imageModel,
-      response_format: { type: "json_object" },
-      messages: [
-        {
-          role: "user",
-          content: [
-            { type: "text", text: prompt },
-            {
-              type: "image_url",
-              image_url: {
-                url: `data:${contentType};base64,${image}`,
-                detail: "low",
-              },
-            },
-          ],
-        },
-      ],
-      max_tokens: 2000,
-    });
-
-    const response = chatCompletion.choices[0].message.content;
-    if (!response) {
-      throw new Error(`Got no message content from OpenAI`);
-    }
-    return { response, totalTokens: chatCompletion.usage?.total_tokens };
-  }
-}
-
-class OllamaInferenceClient implements InferenceClient {
-  ollama: Ollama;
-
-  constructor() {
-    this.ollama = new Ollama({
-      host: serverConfig.inference.ollamaBaseUrl,
-    });
-  }
-
-  async runModel(model: string, prompt: string, image?: string) {
-    const chatCompletion = await this.ollama.chat({
-      model: model,
-      format: "json",
-      stream: true,
-      keep_alive: serverConfig.inference.ollamaKeepAlive,
-      options: {
-        num_ctx: serverConfig.inference.contextLength,
-      },
-      messages: [
-        { role: "user", content: prompt, images: image ? [image] : undefined },
-      ],
-    });
-
-    let totalTokens = 0;
-    let response = "";
-    try {
-      for await (const part of chatCompletion) {
-        response += part.message.content;
-        if (!isNaN(part.eval_count)) {
-          totalTokens += part.eval_count;
-        }
-        if (!isNaN(part.prompt_eval_count)) {
-          totalTokens += part.prompt_eval_count;
-        }
-      }
-    } catch (e) {
-      // There seems to be a bug in ollama where you can get a successful response but still throw an error.
-      // Using stream + accumulating the response so far is a workaround.
-      // https://github.com/ollama/ollama-js/issues/72
-      totalTokens = NaN;
-      logger.warn(
-        `Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`,
-      );
-    }
-
-    return { response, totalTokens };
-  }
-
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
-    return await this.runModel(serverConfig.inference.textModel, prompt);
-  }
-
-  async inferFromImage(
-    prompt: string,
-    _contentType: string,
-    image: string,
-  ): Promise<InferenceResponse> {
-    return await this.runModel(
-      serverConfig.inference.imageModel,
-      prompt,
-      image,
-    );
-  }
-}
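For context, a minimal sketch of how a worker could consume the removed factory. This is not part of the diff; the caller name summarizeWithInference and the JSON.parse step are assumptions based on the interface above, which has both clients request JSON output.

import { InferenceClientFactory } from "./inference";

// Hypothetical caller: builds a client if either backend is configured,
// runs a text prompt, and parses the JSON response.
async function summarizeWithInference(prompt: string): Promise<unknown | null> {
  const client = InferenceClientFactory.build();
  if (!client) {
    // Neither an OpenAI API key nor an Ollama base URL is configured.
    return null;
  }
  const { response, totalTokens } = await client.inferFromText(prompt);
  console.log(`inference consumed ${totalTokens ?? "unknown"} tokens`);
  return JSON.parse(response);
}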
