Diffstat (limited to 'apps')
-rw-r--r--  apps/workers/inference.ts    | 125
-rw-r--r--  apps/workers/openaiWorker.ts |  97
-rw-r--r--  apps/workers/package.json    |   1
-rw-r--r--  apps/workers/searchWorker.ts |  21
4 files changed, 168 insertions(+), 76 deletions(-)
diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
new file mode 100644
index 00000000..c622dd54
--- /dev/null
+++ b/apps/workers/inference.ts
@@ -0,0 +1,125 @@
+import { Ollama } from "ollama";
+import OpenAI from "openai";
+
+import serverConfig from "@hoarder/shared/config";
+
+export interface InferenceResponse {
+  response: string;
+  totalTokens: number | undefined;
+}
+
+export interface InferenceClient {
+  inferFromText(prompt: string): Promise<InferenceResponse>;
+  inferFromImage(
+    prompt: string,
+    contentType: string,
+    image: string,
+  ): Promise<InferenceResponse>;
+}
+
+export class InferenceClientFactory {
+  static build(): InferenceClient | null {
+    if (serverConfig.inference.openAIApiKey) {
+      return new OpenAIInferenceClient();
+    }
+
+    if (serverConfig.inference.ollamaBaseUrl) {
+      return new OllamaInferenceClient();
+    }
+    return null;
+  }
+}
+
+class OpenAIInferenceClient implements InferenceClient {
+  openAI: OpenAI;
+
+  constructor() {
+    this.openAI = new OpenAI({
+      apiKey: serverConfig.inference.openAIApiKey,
+      baseURL: serverConfig.inference.openAIBaseUrl,
+    });
+  }
+
+  async inferFromText(prompt: string): Promise<InferenceResponse> {
+    const chatCompletion = await this.openAI.chat.completions.create({
+      messages: [{ role: "system", content: prompt }],
+      model: serverConfig.inference.textModel,
+      response_format: { type: "json_object" },
+    });
+
+    const response = chatCompletion.choices[0].message.content;
+    if (!response) {
+      throw new Error(`Got no message content from OpenAI`);
+    }
+    return { response, totalTokens: chatCompletion.usage?.total_tokens };
+  }
+
+  async inferFromImage(
+    prompt: string,
+    contentType: string,
+    image: string,
+  ): Promise<InferenceResponse> {
+    const chatCompletion = await this.openAI.chat.completions.create({
+      model: serverConfig.inference.imageModel,
+      messages: [
+        {
+          role: "user",
+          content: [
+            { type: "text", text: prompt },
+            {
+              type: "image_url",
+              image_url: {
+                url: `data:${contentType};base64,${image}`,
+                detail: "low",
+              },
+            },
+          ],
+        },
+      ],
+      max_tokens: 2000,
+    });
+
+    const response = chatCompletion.choices[0].message.content;
+    if (!response) {
+      throw new Error(`Got no message content from OpenAI`);
+    }
+    return { response, totalTokens: chatCompletion.usage?.total_tokens };
+  }
+}
+
+class OllamaInferenceClient implements InferenceClient {
+  ollama: Ollama;
+
+  constructor() {
+    this.ollama = new Ollama({
+      host: serverConfig.inference.ollamaBaseUrl,
+    });
+  }
+
+  async inferFromText(prompt: string): Promise<InferenceResponse> {
+    const chatCompletion = await this.ollama.chat({
+      model: serverConfig.inference.textModel,
+      format: "json",
+      messages: [{ role: "system", content: prompt }],
+    });
+
+    const response = chatCompletion.message.content;
+
+    return { response, totalTokens: chatCompletion.eval_count };
+  }
+
+  async inferFromImage(
+    prompt: string,
+    _contentType: string,
+    image: string,
+  ): Promise<InferenceResponse> {
+    const chatCompletion = await this.ollama.chat({
+      model: serverConfig.inference.imageModel,
+      format: "json",
+      messages: [{ role: "user", content: prompt, images: [`${image}`] }],
+    });
+
+    const response = chatCompletion.message.content;
+    return { response, totalTokens: chatCompletion.eval_count };
+  }
+}
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
index 5f785f2f..b706fb90 100644
--- a/apps/workers/openaiWorker.ts
+++ b/apps/workers/openaiWorker.ts
@@ -1,13 +1,11 @@
 import { Job, Worker } from "bullmq";
 import { and, eq, inArray } from "drizzle-orm";
-import OpenAI from "openai";
 import { z } from "zod";
 
 import { db } from "@hoarder/db";
 import { bookmarks, bookmarkTags, tagsOnBookmarks } from "@hoarder/db/schema";
-import serverConfig from "@hoarder/shared/config";
-import logger from "@hoarder/shared/logger";
 import { readAsset } from "@hoarder/shared/assetdb";
+import logger from "@hoarder/shared/logger";
 import {
   OpenAIQueue,
   queueConnectionDetails,
@@ -16,6 +14,8 @@ import {
   zOpenAIRequestSchema,
 } from "@hoarder/shared/queues";
 
+import { InferenceClientFactory, InferenceClient } from "./inference";
+
 const openAIResponseSchema = z.object({
   tags: z.array(z.string()),
 });
@@ -41,8 +41,8 @@ async function attemptMarkTaggingStatus(
 }
 
 export class OpenAiWorker {
-  static build() {
-    logger.info("Starting openai worker ...");
+  static async build() {
+    logger.info("Starting inference worker ...");
     const worker = new Worker<ZOpenAIRequest, void>(
       OpenAIQueue.name,
       runOpenAI,
@@ -54,13 +54,13 @@ export class OpenAiWorker {
 
     worker.on("completed", async (job): Promise<void> => {
       const jobId = job?.id ?? "unknown";
-      logger.info(`[openai][${jobId}] Completed successfully`);
+      logger.info(`[inference][${jobId}] Completed successfully`);
       await attemptMarkTaggingStatus(job?.data, "success");
     });
 
     worker.on("failed", async (job, error): Promise<void> => {
       const jobId = job?.id ?? "unknown";
-      logger.error(`[openai][${jobId}] openai job failed: ${error}`);
+      logger.error(`[inference][${jobId}] inference job failed: ${error}`);
       await attemptMarkTaggingStatus(job?.data, "failure");
     });
 
@@ -138,82 +138,52 @@ async function fetchBookmark(linkId: string) {
 async function inferTagsFromImage(
   jobId: string,
   bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
-  openai: OpenAI,
+  inferenceClient: InferenceClient,
 ) {
-
   const { asset, metadata } = await readAsset({
     userId: bookmark.userId,
     assetId: bookmark.asset.assetId,
   });
   if (!asset) {
-    throw new Error(`[openai][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`);
+    throw new Error(
+      `[inference][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`,
+    );
   }
-  const base64 = asset.toString('base64');
-
-  const chatCompletion = await openai.chat.completions.create({
-    model: serverConfig.inference.imageModel,
-    messages: [
-      {
-        role: "user",
-        content: [
-          { type: "text", text: IMAGE_PROMPT_BASE },
-          {
-            type: "image_url",
-            image_url: {
-              url: `data:${metadata.contentType};base64,${base64}`,
-              detail: "low",
-            },
-          },
-        ],
-      },
-    ],
-    max_tokens: 2000,
-  });
+  const base64 = asset.toString("base64");
 
-  const response = chatCompletion.choices[0].message.content;
-  if (!response) {
-    throw new Error(`[openai][${jobId}] Got no message content from OpenAI`);
-  }
-  return { response, totalTokens: chatCompletion.usage?.total_tokens };
+  return await inferenceClient.inferFromImage(
+    IMAGE_PROMPT_BASE,
+    metadata.contentType,
+    base64,
+  );
 }
 
 async function inferTagsFromText(
-  jobId: string,
   bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
-  openai: OpenAI,
+  inferenceClient: InferenceClient,
 ) {
-  const chatCompletion = await openai.chat.completions.create({
-    messages: [{ role: "system", content: buildPrompt(bookmark) }],
-    model: serverConfig.inference.textModel,
-    response_format: { type: "json_object" },
-  });
-
-  const response = chatCompletion.choices[0].message.content;
-  if (!response) {
-    throw new Error(`[openai][${jobId}] Got no message content from OpenAI`);
-  }
-  return { response, totalTokens: chatCompletion.usage?.total_tokens };
+  return await inferenceClient.inferFromText(buildPrompt(bookmark));
 }
 
 async function inferTags(
   jobId: string,
   bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
-  openai: OpenAI,
+  inferenceClient: InferenceClient,
 ) {
   let response;
   if (bookmark.link || bookmark.text) {
-    response = await inferTagsFromText(jobId, bookmark, openai);
+    response = await inferTagsFromText(bookmark, inferenceClient);
   } else if (bookmark.asset) {
-    response = await inferTagsFromImage(jobId, bookmark, openai);
+    response = await inferTagsFromImage(jobId, bookmark, inferenceClient);
   } else {
-    throw new Error(`[openai][${jobId}] Unsupported bookmark type`);
+    throw new Error(`[inference][${jobId}] Unsupported bookmark type`);
   }
 
   try {
     let tags = openAIResponseSchema.parse(JSON.parse(response.response)).tags;
     logger.info(
-      `[openai][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${response.totalTokens} tokens and inferred: ${tags}`,
+      `[inference][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${response.totalTokens} tokens and inferred: ${tags}`,
    );
 
     // Sometimes the tags contain the hashtag symbol, let's strip them out if they do.
@@ -227,7 +197,7 @@ async function inferTags(
     return tags;
   } catch (e) {
     throw new Error(
-      `[openai][${jobId}] Failed to parse JSON response from OpenAI: ${e}`,
+      `[inference][${jobId}] Failed to parse JSON response from inference client: ${e}`,
     );
   }
 }
@@ -292,23 +262,18 @@ async function connectTags(
 async function runOpenAI(job: Job<ZOpenAIRequest, void>) {
   const jobId = job.id ?? "unknown";
 
-  const { inference } = serverConfig;
-
-  if (!inference.openAIApiKey) {
+  const inferenceClient = InferenceClientFactory.build();
+  if (!inferenceClient) {
     logger.debug(
-      `[openai][${jobId}] OpenAI is not configured, nothing to do now`,
+      `[inference][${jobId}] No inference client configured, nothing to do now`,
     );
     return;
   }
 
-  const openai = new OpenAI({
-    apiKey: inference.openAIApiKey,
-  });
-
   const request = zOpenAIRequestSchema.safeParse(job.data);
   if (!request.success) {
     throw new Error(
-      `[openai][${jobId}] Got malformed job request: ${request.error.toString()}`,
+      `[inference][${jobId}] Got malformed job request: ${request.error.toString()}`,
     );
   }
 
@@ -316,11 +281,11 @@ async function runOpenAI(job: Job<ZOpenAIRequest, void>) {
   const bookmark = await fetchBookmark(bookmarkId);
   if (!bookmark) {
     throw new Error(
-      `[openai][${jobId}] bookmark with id ${bookmarkId} was not found`,
+      `[inference][${jobId}] bookmark with id ${bookmarkId} was not found`,
     );
   }
 
-  const tags = await inferTags(jobId, bookmark, openai);
+  const tags = await inferTags(jobId, bookmark, inferenceClient);
 
   await connectTags(bookmarkId, tags, bookmark.userId);
 
diff --git a/apps/workers/package.json b/apps/workers/package.json
index f6d58eb4..27a02f88 100644
--- a/apps/workers/package.json
+++ b/apps/workers/package.json
@@ -24,6 +24,7 @@
     "metascraper-title": "^5.43.4",
     "metascraper-twitter": "^5.43.4",
     "metascraper-url": "^5.43.4",
+    "ollama": "^0.5.0",
     "openai": "^4.29.0",
     "puppeteer": "^22.0.0",
     "puppeteer-extra": "^3.3.6",
diff --git a/apps/workers/searchWorker.ts b/apps/workers/searchWorker.ts
index 618e7c89..b24777d7 100644
--- a/apps/workers/searchWorker.ts
+++ b/apps/workers/searchWorker.ts
@@ -1,16 +1,17 @@
+import type { Job } from "bullmq";
+import { Worker } from "bullmq";
+import { eq } from "drizzle-orm";
+
+import type { ZSearchIndexingRequest } from "@hoarder/shared/queues";
 import { db } from "@hoarder/db";
+import { bookmarks } from "@hoarder/db/schema";
 import logger from "@hoarder/shared/logger";
-import { getSearchIdxClient } from "@hoarder/shared/search";
 import {
-  SearchIndexingQueue,
-  ZSearchIndexingRequest,
   queueConnectionDetails,
+  SearchIndexingQueue,
   zSearchIndexingRequestSchema,
 } from "@hoarder/shared/queues";
-import { Job } from "bullmq";
-import { Worker } from "bullmq";
-import { bookmarks } from "@hoarder/db/schema";
-import { eq } from "drizzle-orm";
+import { getSearchIdxClient } from "@hoarder/shared/search";
 
 export class SearchIndexingWorker {
   static async build() {
@@ -25,12 +26,12 @@ export class SearchIndexingWorker {
     );
 
     worker.on("completed", (job) => {
-      const jobId = job?.id || "unknown";
+      const jobId = job?.id ?? "unknown";
       logger.info(`[search][${jobId}] Completed successfully`);
     });
 
     worker.on("failed", (job, error) => {
-      const jobId = job?.id || "unknown";
+      const jobId = job?.id ?? "unknown";
       logger.error(`[search][${jobId}] openai job failed: ${error}`);
     });
 
@@ -85,7 +86,7 @@ async function runDelete(
 }
 
 async function runSearchIndexing(job: Job<ZSearchIndexingRequest, void>) {
-  const jobId = job.id || "unknown";
+  const jobId = job.id ?? "unknown";
   const request = zSearchIndexingRequestSchema.safeParse(job.data);
   if (!request.success) {
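For reference, a minimal sketch of how the new abstraction is meant to be consumed, mirroring what runOpenAI does after this change. The calling function and its prompt argument below are hypothetical placeholders, not part of the diff:

  // Hypothetical caller; mirrors the factory usage in openaiWorker.ts.
  import { InferenceClientFactory } from "./inference";

  async function tagWithConfiguredBackend(prompt: string) {
    // build() picks OpenAI when serverConfig.inference.openAIApiKey is set,
    // falls back to Ollama when only ollamaBaseUrl is configured,
    // and returns null when neither backend is available.
    const client = InferenceClientFactory.build();
    if (!client) {
      return null; // Nothing configured; the worker simply skips the job.
    }

    // Both backends are asked for JSON output (response_format json_object
    // for OpenAI, format: "json" for Ollama), so the caller can parse directly.
    const { response, totalTokens } = await client.inferFromText(prompt);
    return { result: JSON.parse(response), totalTokens };
  }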

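The new client code reads five fields from serverConfig.inference. The actual schema lives in @hoarder/shared/config, which is outside this diff, so the shape below is an assumption reconstructed from the fields referenced above:

  // Assumed shape of serverConfig.inference (not part of this diff).
  interface InferenceConfig {
    openAIApiKey?: string;  // when set, InferenceClientFactory picks OpenAI
    openAIBaseUrl?: string; // optional override, e.g. an OpenAI-compatible proxy
    ollamaBaseUrl?: string; // when set (and no OpenAI key), Ollama is used
    textModel: string;      // model name used by inferFromText
    imageModel: string;     // model name used by inferFromImage
  }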