From 3e727f7ba3ad157ca1ccc6100711266cae1bde23 Mon Sep 17 00:00:00 2001
From: Mohamed Bassem
Date: Sat, 26 Oct 2024 20:07:16 +0000
Subject: refactor: Move inference to the shared package

---
 apps/workers/inference.ts    | 155 -------------------------------------------
 apps/workers/openaiWorker.ts |   4 +-
 apps/workers/package.json    |   2 -
 3 files changed, 2 insertions(+), 159 deletions(-)
 delete mode 100644 apps/workers/inference.ts

(limited to 'apps')

diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
deleted file mode 100644
index fed9478f..00000000
--- a/apps/workers/inference.ts
+++ /dev/null
@@ -1,155 +0,0 @@
-import { Ollama } from "ollama";
-import OpenAI from "openai";
-
-import serverConfig from "@hoarder/shared/config";
-import logger from "@hoarder/shared/logger";
-
-export interface InferenceResponse {
-  response: string;
-  totalTokens: number | undefined;
-}
-
-export interface InferenceClient {
-  inferFromText(prompt: string): Promise<InferenceResponse>;
-  inferFromImage(
-    prompt: string,
-    contentType: string,
-    image: string,
-  ): Promise<InferenceResponse>;
-}
-
-export class InferenceClientFactory {
-  static build(): InferenceClient | null {
-    if (serverConfig.inference.openAIApiKey) {
-      return new OpenAIInferenceClient();
-    }
-
-    if (serverConfig.inference.ollamaBaseUrl) {
-      return new OllamaInferenceClient();
-    }
-    return null;
-  }
-}
-
-class OpenAIInferenceClient implements InferenceClient {
-  openAI: OpenAI;
-
-  constructor() {
-    this.openAI = new OpenAI({
-      apiKey: serverConfig.inference.openAIApiKey,
-      baseURL: serverConfig.inference.openAIBaseUrl,
-    });
-  }
-
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
-    const chatCompletion = await this.openAI.chat.completions.create({
-      messages: [{ role: "user", content: prompt }],
-      model: serverConfig.inference.textModel,
-      response_format: { type: "json_object" },
-    });
-
-    const response = chatCompletion.choices[0].message.content;
-    if (!response) {
-      throw new Error(`Got no message content from OpenAI`);
-    }
-    return { response, totalTokens: chatCompletion.usage?.total_tokens };
-  }
-
-  async inferFromImage(
-    prompt: string,
-    contentType: string,
-    image: string,
-  ): Promise<InferenceResponse> {
-    const chatCompletion = await this.openAI.chat.completions.create({
-      model: serverConfig.inference.imageModel,
-      response_format: { type: "json_object" },
-      messages: [
-        {
-          role: "user",
-          content: [
-            { type: "text", text: prompt },
-            {
-              type: "image_url",
-              image_url: {
-                url: `data:${contentType};base64,${image}`,
-                detail: "low",
-              },
-            },
-          ],
-        },
-      ],
-      max_tokens: 2000,
-    });
-
-    const response = chatCompletion.choices[0].message.content;
-    if (!response) {
-      throw new Error(`Got no message content from OpenAI`);
-    }
-    return { response, totalTokens: chatCompletion.usage?.total_tokens };
-  }
-}
-
-class OllamaInferenceClient implements InferenceClient {
-  ollama: Ollama;
-
-  constructor() {
-    this.ollama = new Ollama({
-      host: serverConfig.inference.ollamaBaseUrl,
-    });
-  }
-
-  async runModel(model: string, prompt: string, image?: string) {
-    const chatCompletion = await this.ollama.chat({
-      model: model,
-      format: "json",
-      stream: true,
-      keep_alive: serverConfig.inference.ollamaKeepAlive,
-      options: {
-        num_ctx: serverConfig.inference.contextLength,
-      },
-      messages: [
-        { role: "user", content: prompt, images: image ? [image] : undefined },
-      ],
-    });
-
-    let totalTokens = 0;
-    let response = "";
-    try {
-      for await (const part of chatCompletion) {
-        response += part.message.content;
-        if (!isNaN(part.eval_count)) {
-          totalTokens += part.eval_count;
-        }
-        if (!isNaN(part.prompt_eval_count)) {
-          totalTokens += part.prompt_eval_count;
-        }
-      }
-    } catch (e) {
-      // There seem to be some bug in ollama where you can get some successfull response, but still throw an error.
-      // Using stream + accumulating the response so far is a workaround.
-      // https://github.com/ollama/ollama-js/issues/72
-      totalTokens = NaN;
-      logger.warn(
-        `Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`,
-      );
-    }
-
-    return { response, totalTokens };
-  }
-
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
-    return await this.runModel(serverConfig.inference.textModel, prompt);
-  }
-
-  async inferFromImage(
-    prompt: string,
-    _contentType: string,
-    image: string,
-  ): Promise<InferenceResponse> {
-    return await this.runModel(
-      serverConfig.inference.imageModel,
-      prompt,
-      image,
-    );
-  }
-}
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
index f436f71b..b1394f73 100644
--- a/apps/workers/openaiWorker.ts
+++ b/apps/workers/openaiWorker.ts
@@ -1,6 +1,7 @@
 import { and, Column, eq, inArray, sql } from "drizzle-orm";
 import { z } from "zod";
 
+import type { InferenceClient } from "@hoarder/shared/inference";
 import type { ZOpenAIRequest } from "@hoarder/shared/queues";
 import { db } from "@hoarder/db";
 import {
@@ -13,6 +14,7 @@ import {
 import { DequeuedJob, Runner } from "@hoarder/queue";
 import { readAsset } from "@hoarder/shared/assetdb";
 import serverConfig from "@hoarder/shared/config";
+import { InferenceClientFactory } from "@hoarder/shared/inference";
 import logger from "@hoarder/shared/logger";
 import { buildImagePrompt, buildTextPrompt } from "@hoarder/shared/prompts";
 import {
@@ -21,8 +23,6 @@ import {
   zOpenAIRequestSchema,
 } from "@hoarder/shared/queues";
 
-import type { InferenceClient } from "./inference";
-import { InferenceClientFactory } from "./inference";
 import { readImageText, readPDFText } from "./utils";
 
 const openAIResponseSchema = z.object({
diff --git a/apps/workers/package.json b/apps/workers/package.json
index 0ab7caa2..289f7315 100644
--- a/apps/workers/package.json
+++ b/apps/workers/package.json
@@ -26,8 +26,6 @@
     "metascraper-title": "^5.45.22",
     "metascraper-twitter": "^5.45.6",
     "metascraper-url": "^5.45.22",
-    "ollama": "^0.5.9",
-    "openai": "^4.67.1",
     "pdf2json": "^3.0.5",
     "pdfjs-dist": "^4.0.379",
     "puppeteer": "^22.0.0",
--
cgit v1.2.3-70-g09d2
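
For reference, a minimal sketch of how a worker obtains a client after this refactor, assuming the shared package at "@hoarder/shared/inference" re-exports the same InferenceClient / InferenceClientFactory API that the deleted apps/workers/inference.ts defined (the updated imports in openaiWorker.ts above suggest it does). The prompt string is purely illustrative.

import type { InferenceClient } from "@hoarder/shared/inference";
import { InferenceClientFactory } from "@hoarder/shared/inference";

// build() returns null when neither an OpenAI API key nor an Ollama base URL
// is configured, so callers must handle the null case.
const client: InferenceClient | null = InferenceClientFactory.build();

if (client) {
  // Both backends are configured for JSON output, so the prompt should ask
  // for JSON. inferFromText resolves to { response, totalTokens }.
  const { response, totalTokens } = await client.inferFromText(
    "Reply with a JSON object of suggested tags for this text: ...",
  );
  console.log(`response: ${response} (tokens used: ${totalTokens ?? "unknown"})`);
}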