From 3e727f7ba3ad157ca1ccc6100711266cae1bde23 Mon Sep 17 00:00:00 2001
From: Mohamed Bassem
Date: Sat, 26 Oct 2024 20:07:16 +0000
Subject: refactor: Move inference to the shared package

---
 apps/workers/inference.ts    | 155 -------------------------------------------
 apps/workers/openaiWorker.ts |   4 +-
 apps/workers/package.json    |   2 -
 packages/shared/inference.ts | 155 +++++++++++++++++++++++++++++++++++++++++++
 packages/shared/package.json |   2 +
 pnpm-lock.yaml               |  13 ++--
 6 files changed, 166 insertions(+), 165 deletions(-)
 delete mode 100644 apps/workers/inference.ts
 create mode 100644 packages/shared/inference.ts

diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
deleted file mode 100644
index fed9478f..00000000
--- a/apps/workers/inference.ts
+++ /dev/null
@@ -1,155 +0,0 @@
-import { Ollama } from "ollama";
-import OpenAI from "openai";
-
-import serverConfig from "@hoarder/shared/config";
-import logger from "@hoarder/shared/logger";
-
-export interface InferenceResponse {
-  response: string;
-  totalTokens: number | undefined;
-}
-
-export interface InferenceClient {
-  inferFromText(prompt: string): Promise<InferenceResponse>;
-  inferFromImage(
-    prompt: string,
-    contentType: string,
-    image: string,
-  ): Promise<InferenceResponse>;
-}
-
-export class InferenceClientFactory {
-  static build(): InferenceClient | null {
-    if (serverConfig.inference.openAIApiKey) {
-      return new OpenAIInferenceClient();
-    }
-
-    if (serverConfig.inference.ollamaBaseUrl) {
-      return new OllamaInferenceClient();
-    }
-    return null;
-  }
-}
-
-class OpenAIInferenceClient implements InferenceClient {
-  openAI: OpenAI;
-
-  constructor() {
-    this.openAI = new OpenAI({
-      apiKey: serverConfig.inference.openAIApiKey,
-      baseURL: serverConfig.inference.openAIBaseUrl,
-    });
-  }
-
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
-    const chatCompletion = await this.openAI.chat.completions.create({
-      messages: [{ role: "user", content: prompt }],
-      model: serverConfig.inference.textModel,
-      response_format: { type: "json_object" },
-    });
-
-    const response = chatCompletion.choices[0].message.content;
-    if (!response) {
-      throw new Error(`Got no message content from OpenAI`);
-    }
-    return { response, totalTokens: chatCompletion.usage?.total_tokens };
-  }
-
-  async inferFromImage(
-    prompt: string,
-    contentType: string,
-    image: string,
-  ): Promise<InferenceResponse> {
-    const chatCompletion = await this.openAI.chat.completions.create({
-      model: serverConfig.inference.imageModel,
-      response_format: { type: "json_object" },
-      messages: [
-        {
-          role: "user",
-          content: [
-            { type: "text", text: prompt },
-            {
-              type: "image_url",
-              image_url: {
-                url: `data:${contentType};base64,${image}`,
-                detail: "low",
-              },
-            },
-          ],
-        },
-      ],
-      max_tokens: 2000,
-    });
-
-    const response = chatCompletion.choices[0].message.content;
-    if (!response) {
-      throw new Error(`Got no message content from OpenAI`);
-    }
-    return { response, totalTokens: chatCompletion.usage?.total_tokens };
-  }
-}
-
-class OllamaInferenceClient implements InferenceClient {
-  ollama: Ollama;
-
-  constructor() {
-    this.ollama = new Ollama({
-      host: serverConfig.inference.ollamaBaseUrl,
-    });
-  }
-
-  async runModel(model: string, prompt: string, image?: string) {
-    const chatCompletion = await this.ollama.chat({
-      model: model,
-      format: "json",
-      stream: true,
-      keep_alive: serverConfig.inference.ollamaKeepAlive,
-      options: {
-        num_ctx: serverConfig.inference.contextLength,
-      },
-      messages: [
-        { role: "user", content: prompt, images: image ? [image] : undefined },
-      ],
-    });
-
-    let totalTokens = 0;
-    let response = "";
-    try {
-      for await (const part of chatCompletion) {
-        response += part.message.content;
-        if (!isNaN(part.eval_count)) {
-          totalTokens += part.eval_count;
-        }
-        if (!isNaN(part.prompt_eval_count)) {
-          totalTokens += part.prompt_eval_count;
-        }
-      }
-    } catch (e) {
-      // There seem to be some bug in ollama where you can get some successfull response, but still throw an error.
-      // Using stream + accumulating the response so far is a workaround.
-      // https://github.com/ollama/ollama-js/issues/72
-      totalTokens = NaN;
-      logger.warn(
-        `Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`,
-      );
-    }
-
-    return { response, totalTokens };
-  }
-
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
-    return await this.runModel(serverConfig.inference.textModel, prompt);
-  }
-
-  async inferFromImage(
-    prompt: string,
-    _contentType: string,
-    image: string,
-  ): Promise<InferenceResponse> {
-    return await this.runModel(
-      serverConfig.inference.imageModel,
-      prompt,
-      image,
-    );
-  }
-}
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
index f436f71b..b1394f73 100644
--- a/apps/workers/openaiWorker.ts
+++ b/apps/workers/openaiWorker.ts
@@ -1,6 +1,7 @@
 import { and, Column, eq, inArray, sql } from "drizzle-orm";
 import { z } from "zod";
 
+import type { InferenceClient } from "@hoarder/shared/inference";
 import type { ZOpenAIRequest } from "@hoarder/shared/queues";
 import { db } from "@hoarder/db";
 import {
@@ -13,6 +14,7 @@ import {
 import { DequeuedJob, Runner } from "@hoarder/queue";
 import { readAsset } from "@hoarder/shared/assetdb";
 import serverConfig from "@hoarder/shared/config";
+import { InferenceClientFactory } from "@hoarder/shared/inference";
 import logger from "@hoarder/shared/logger";
 import { buildImagePrompt, buildTextPrompt } from "@hoarder/shared/prompts";
 import {
@@ -21,8 +23,6 @@ import {
   zOpenAIRequestSchema,
 } from "@hoarder/shared/queues";
 
-import type { InferenceClient } from "./inference";
-import { InferenceClientFactory } from "./inference";
 import { readImageText, readPDFText } from "./utils";
 
 const openAIResponseSchema = z.object({
diff --git a/apps/workers/package.json b/apps/workers/package.json
index 0ab7caa2..289f7315 100644
--- a/apps/workers/package.json
+++ b/apps/workers/package.json
@@ -26,8 +26,6 @@
     "metascraper-title": "^5.45.22",
     "metascraper-twitter": "^5.45.6",
     "metascraper-url": "^5.45.22",
-    "ollama": "^0.5.9",
-    "openai": "^4.67.1",
     "pdf2json": "^3.0.5",
     "pdfjs-dist": "^4.0.379",
     "puppeteer": "^22.0.0",
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
new file mode 100644
index 00000000..f34c2880
--- /dev/null
+++ b/packages/shared/inference.ts
@@ -0,0 +1,155 @@
+import { Ollama } from "ollama";
+import OpenAI from "openai";
+
+import serverConfig from "./config";
+import logger from "./logger";
+
+export interface InferenceResponse {
+  response: string;
+  totalTokens: number | undefined;
+}
+
+export interface InferenceClient {
+  inferFromText(prompt: string): Promise<InferenceResponse>;
+  inferFromImage(
+    prompt: string,
+    contentType: string,
+    image: string,
+  ): Promise<InferenceResponse>;
+}
+
+export class InferenceClientFactory {
+  static build(): InferenceClient | null {
+    if (serverConfig.inference.openAIApiKey) {
+      return new OpenAIInferenceClient();
+    }
+
+    if (serverConfig.inference.ollamaBaseUrl) {
+      return new OllamaInferenceClient();
+    }
+    return null;
+  }
+}
+
+class OpenAIInferenceClient implements InferenceClient {
+  openAI: OpenAI;
+
+  constructor() {
+    this.openAI = new OpenAI({
+      apiKey: serverConfig.inference.openAIApiKey,
+      baseURL: serverConfig.inference.openAIBaseUrl,
+    });
+  }
+
+  async inferFromText(prompt: string): Promise<InferenceResponse> {
+    const chatCompletion = await this.openAI.chat.completions.create({
+      messages: [{ role: "user", content: prompt }],
+      model: serverConfig.inference.textModel,
+      response_format: { type: "json_object" },
+    });
+
+    const response = chatCompletion.choices[0].message.content;
+    if (!response) {
+      throw new Error(`Got no message content from OpenAI`);
+    }
+    return { response, totalTokens: chatCompletion.usage?.total_tokens };
+  }
+
+  async inferFromImage(
+    prompt: string,
+    contentType: string,
+    image: string,
+  ): Promise<InferenceResponse> {
+    const chatCompletion = await this.openAI.chat.completions.create({
+      model: serverConfig.inference.imageModel,
+      response_format: { type: "json_object" },
+      messages: [
+        {
+          role: "user",
+          content: [
+            { type: "text", text: prompt },
+            {
+              type: "image_url",
+              image_url: {
+                url: `data:${contentType};base64,${image}`,
+                detail: "low",
+              },
+            },
+          ],
+        },
+      ],
+      max_tokens: 2000,
+    });
+
+    const response = chatCompletion.choices[0].message.content;
+    if (!response) {
+      throw new Error(`Got no message content from OpenAI`);
+    }
+    return { response, totalTokens: chatCompletion.usage?.total_tokens };
+  }
+}
+
+class OllamaInferenceClient implements InferenceClient {
+  ollama: Ollama;
+
+  constructor() {
+    this.ollama = new Ollama({
+      host: serverConfig.inference.ollamaBaseUrl,
+    });
+  }
+
+  async runModel(model: string, prompt: string, image?: string) {
+    const chatCompletion = await this.ollama.chat({
+      model: model,
+      format: "json",
+      stream: true,
+      keep_alive: serverConfig.inference.ollamaKeepAlive,
+      options: {
+        num_ctx: serverConfig.inference.contextLength,
+      },
+      messages: [
+        { role: "user", content: prompt, images: image ? [image] : undefined },
+      ],
+    });
+
+    let totalTokens = 0;
+    let response = "";
+    try {
+      for await (const part of chatCompletion) {
+        response += part.message.content;
+        if (!isNaN(part.eval_count)) {
+          totalTokens += part.eval_count;
+        }
+        if (!isNaN(part.prompt_eval_count)) {
+          totalTokens += part.prompt_eval_count;
+        }
+      }
+    } catch (e) {
+      // There seem to be some bug in ollama where you can get some successfull response, but still throw an error.
+      // Using stream + accumulating the response so far is a workaround.
+      // https://github.com/ollama/ollama-js/issues/72
+      totalTokens = NaN;
+      logger.warn(
+        `Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`,
+      );
+    }
+
+    return { response, totalTokens };
+  }
+
+  async inferFromText(prompt: string): Promise<InferenceResponse> {
+    return await this.runModel(serverConfig.inference.textModel, prompt);
+  }
+
+  async inferFromImage(
+    prompt: string,
+    _contentType: string,
+    image: string,
+  ): Promise<InferenceResponse> {
+    return await this.runModel(
+      serverConfig.inference.imageModel,
+      prompt,
+      image,
+    );
+  }
+}
diff --git a/packages/shared/package.json b/packages/shared/package.json
index 69d93075..f6774263 100644
--- a/packages/shared/package.json
+++ b/packages/shared/package.json
@@ -8,6 +8,8 @@
     "@hoarder/queue": "workspace:^0.1.0",
     "glob": "^11.0.0",
     "meilisearch": "^0.37.0",
+    "ollama": "^0.5.9",
+    "openai": "^4.67.1",
     "winston": "^3.11.0",
     "zod": "^3.22.4"
   },
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index d6638882..04cb0c62 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -749,12 +749,6 @@ importers:
       metascraper-url:
         specifier: ^5.45.22
         version: 5.45.22
-      ollama:
-        specifier: ^0.5.9
-        version: 0.5.9
-      openai:
-        specifier: ^4.67.1
-        version: 4.67.1(zod@3.22.4)
       pdf2json:
         specifier: ^3.0.5
         version: 3.0.5
@@ -955,6 +949,12 @@ importers:
       meilisearch:
        specifier: ^0.37.0
        version: 0.37.0
+      ollama:
+        specifier: ^0.5.9
+        version: 0.5.9
+      openai:
+        specifier: ^4.67.1
+        version: 4.67.1(zod@3.22.4)
       winston:
        specifier: ^3.11.0
        version: 3.11.0
@@ -4709,6 +4709,7 @@ packages:
   '@xmldom/xmldom@0.7.13':
     resolution: {integrity: sha512-lm2GW5PkosIzccsaZIz7tp8cPADSIlIHWDFTR1N0SzfinhhYgeIQjFMz4rYzanCScr3DqQLeomUDArp6MWKm+g==}
     engines: {node: '>=10.0.0'}
+    deprecated: this version is no longer supported, please update to at least 0.8.*
 
   '@xmldom/xmldom@0.8.10':
     resolution: {integrity: sha512-2WALfTl4xo2SkGCYRt6rDTFfk9R1czmBvUQy12gK2KuRKIpWEhcbbzy8EZXtz/jkRqHX8bFEc6FC1HjX4TUWYw==}
-- 
cgit v1.2.3-70-g09d2
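
For context, a minimal sketch of how a worker (or any other package in the monorepo) might consume the relocated module after this change. The summarizeText helper, its prompt, and the expected {"summary": string} shape are illustrative assumptions and not part of the patch; only InferenceClientFactory, InferenceClient, and InferenceResponse come from the code moved above.

import { InferenceClientFactory } from "@hoarder/shared/inference";

// Hypothetical helper: summarize a piece of text with whichever backend is configured.
async function summarizeText(text: string): Promise<string | null> {
  // build() returns null when neither an OpenAI API key nor an Ollama base URL is configured.
  const inferenceClient = InferenceClientFactory.build();
  if (!inferenceClient) {
    return null;
  }
  // Both backends request JSON-formatted output, so the prompt should ask for a JSON object.
  const { response, totalTokens } = await inferenceClient.inferFromText(
    `Reply with a JSON object of the form {"summary": string} that summarizes the following text:\n${text}`,
  );
  console.log(`Inference used ${totalTokens ?? "an unknown number of"} tokens.`);
  return (JSON.parse(response) as { summary: string }).summary;
}

Keeping backend selection behind InferenceClientFactory is what lets inference.ts move from apps/workers to packages/shared without touching any OpenAI- or Ollama-specific call sites.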