Diffstat (limited to 'apps')
-rw-r--r--  apps/workers/inference.ts     155
-rw-r--r--  apps/workers/openaiWorker.ts    4
-rw-r--r--  apps/workers/package.json       2
3 files changed, 2 insertions, 159 deletions
diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
deleted file mode 100644
index fed9478f..00000000
--- a/apps/workers/inference.ts
+++ /dev/null
@@ -1,155 +0,0 @@
-import { Ollama } from "ollama";
-import OpenAI from "openai";
-
-import serverConfig from "@hoarder/shared/config";
-import logger from "@hoarder/shared/logger";
-
-export interface InferenceResponse {
-  response: string;
-  totalTokens: number | undefined;
-}
-
-export interface InferenceClient {
-  inferFromText(prompt: string): Promise<InferenceResponse>;
-  inferFromImage(
-    prompt: string,
-    contentType: string,
-    image: string,
-  ): Promise<InferenceResponse>;
-}
-
-export class InferenceClientFactory {
-  static build(): InferenceClient | null {
-    if (serverConfig.inference.openAIApiKey) {
-      return new OpenAIInferenceClient();
-    }
-
-    if (serverConfig.inference.ollamaBaseUrl) {
-      return new OllamaInferenceClient();
-    }
-    return null;
-  }
-}
-
-class OpenAIInferenceClient implements InferenceClient {
-  openAI: OpenAI;
-
-  constructor() {
-    this.openAI = new OpenAI({
-      apiKey: serverConfig.inference.openAIApiKey,
-      baseURL: serverConfig.inference.openAIBaseUrl,
-    });
-  }
-
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
-    const chatCompletion = await this.openAI.chat.completions.create({
-      messages: [{ role: "user", content: prompt }],
-      model: serverConfig.inference.textModel,
-      response_format: { type: "json_object" },
-    });
-
-    const response = chatCompletion.choices[0].message.content;
-    if (!response) {
-      throw new Error(`Got no message content from OpenAI`);
-    }
-    return { response, totalTokens: chatCompletion.usage?.total_tokens };
-  }
-
-  async inferFromImage(
-    prompt: string,
-    contentType: string,
-    image: string,
-  ): Promise<InferenceResponse> {
-    const chatCompletion = await this.openAI.chat.completions.create({
-      model: serverConfig.inference.imageModel,
-      response_format: { type: "json_object" },
-      messages: [
-        {
-          role: "user",
-          content: [
-            { type: "text", text: prompt },
-            {
-              type: "image_url",
-              image_url: {
-                url: `data:${contentType};base64,${image}`,
-                detail: "low",
-              },
-            },
-          ],
-        },
-      ],
-      max_tokens: 2000,
-    });
-
-    const response = chatCompletion.choices[0].message.content;
-    if (!response) {
-      throw new Error(`Got no message content from OpenAI`);
-    }
-    return { response, totalTokens: chatCompletion.usage?.total_tokens };
-  }
-}
-
-class OllamaInferenceClient implements InferenceClient {
-  ollama: Ollama;
-
-  constructor() {
-    this.ollama = new Ollama({
-      host: serverConfig.inference.ollamaBaseUrl,
-    });
-  }
-
-  async runModel(model: string, prompt: string, image?: string) {
-    const chatCompletion = await this.ollama.chat({
-      model: model,
-      format: "json",
-      stream: true,
-      keep_alive: serverConfig.inference.ollamaKeepAlive,
-      options: {
-        num_ctx: serverConfig.inference.contextLength,
-      },
-      messages: [
-        { role: "user", content: prompt, images: image ? [image] : undefined },
-      ],
-    });
-
-    let totalTokens = 0;
-    let response = "";
-    try {
-      for await (const part of chatCompletion) {
-        response += part.message.content;
-        if (!isNaN(part.eval_count)) {
-          totalTokens += part.eval_count;
-        }
-        if (!isNaN(part.prompt_eval_count)) {
-          totalTokens += part.prompt_eval_count;
-        }
-      }
-    } catch (e) {
-      // There seems to be a bug in ollama where you can get a successful response but still have an error thrown.
-      // Using stream + accumulating the response so far is a workaround.
-      // https://github.com/ollama/ollama-js/issues/72
-      totalTokens = NaN;
-      logger.warn(
-        `Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`,
-      );
-    }
-
-    return { response, totalTokens };
-  }
-
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
-    return await this.runModel(serverConfig.inference.textModel, prompt);
-  }
-
-  async inferFromImage(
-    prompt: string,
-    _contentType: string,
-    image: string,
-  ): Promise<InferenceResponse> {
-    return await this.runModel(
-      serverConfig.inference.imageModel,
-      prompt,
-      image,
-    );
-  }
-}
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
index f436f71b..b1394f73 100644
--- a/apps/workers/openaiWorker.ts
+++ b/apps/workers/openaiWorker.ts
@@ -1,6 +1,7 @@
 import { and, Column, eq, inArray, sql } from "drizzle-orm";
 import { z } from "zod";
 
+import type { InferenceClient } from "@hoarder/shared/inference";
 import type { ZOpenAIRequest } from "@hoarder/shared/queues";
 import { db } from "@hoarder/db";
 import {
@@ -13,6 +14,7 @@ import {
 import { DequeuedJob, Runner } from "@hoarder/queue";
 import { readAsset } from "@hoarder/shared/assetdb";
 import serverConfig from "@hoarder/shared/config";
+import { InferenceClientFactory } from "@hoarder/shared/inference";
 import logger from "@hoarder/shared/logger";
 import { buildImagePrompt, buildTextPrompt } from "@hoarder/shared/prompts";
 import {
@@ -21,8 +23,6 @@ import {
   zOpenAIRequestSchema,
 } from "@hoarder/shared/queues";
 
-import type { InferenceClient } from "./inference";
-import { InferenceClientFactory } from "./inference";
 import { readImageText, readPDFText } from "./utils";
 
 const openAIResponseSchema = z.object({
diff --git a/apps/workers/package.json b/apps/workers/package.json
index 0ab7caa2..289f7315 100644
--- a/apps/workers/package.json
+++ b/apps/workers/package.json
@@ -26,8 +26,6 @@
"metascraper-title": "^5.45.22",
"metascraper-twitter": "^5.45.6",
"metascraper-url": "^5.45.22",
- "ollama": "^0.5.9",
- "openai": "^4.67.1",
"pdf2json": "^3.0.5",
"pdfjs-dist": "^4.0.379",
"puppeteer": "^22.0.0",