diff options
Diffstat (limited to 'apps/workers/openaiWorker.ts')
| -rw-r--r-- | apps/workers/openaiWorker.ts | 97 |
1 files changed, 31 insertions, 66 deletions
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts index 5f785f2f..b706fb90 100644 --- a/apps/workers/openaiWorker.ts +++ b/apps/workers/openaiWorker.ts @@ -1,13 +1,11 @@ import { Job, Worker } from "bullmq"; import { and, eq, inArray } from "drizzle-orm"; -import OpenAI from "openai"; import { z } from "zod"; import { db } from "@hoarder/db"; import { bookmarks, bookmarkTags, tagsOnBookmarks } from "@hoarder/db/schema"; -import serverConfig from "@hoarder/shared/config"; -import logger from "@hoarder/shared/logger"; import { readAsset } from "@hoarder/shared/assetdb"; +import logger from "@hoarder/shared/logger"; import { OpenAIQueue, queueConnectionDetails, @@ -16,6 +14,8 @@ import { zOpenAIRequestSchema, } from "@hoarder/shared/queues"; +import { InferenceClientFactory, InferenceClient } from "./inference"; + const openAIResponseSchema = z.object({ tags: z.array(z.string()), }); @@ -41,8 +41,8 @@ async function attemptMarkTaggingStatus( } export class OpenAiWorker { - static build() { - logger.info("Starting openai worker ..."); + static async build() { + logger.info("Starting inference worker ..."); const worker = new Worker<ZOpenAIRequest, void>( OpenAIQueue.name, runOpenAI, @@ -54,13 +54,13 @@ export class OpenAiWorker { worker.on("completed", async (job): Promise<void> => { const jobId = job?.id ?? "unknown"; - logger.info(`[openai][${jobId}] Completed successfully`); + logger.info(`[inference][${jobId}] Completed successfully`); await attemptMarkTaggingStatus(job?.data, "success"); }); worker.on("failed", async (job, error): Promise<void> => { const jobId = job?.id ?? "unknown"; - logger.error(`[openai][${jobId}] openai job failed: ${error}`); + logger.error(`[inference][${jobId}] inference job failed: ${error}`); await attemptMarkTaggingStatus(job?.data, "failure"); }); @@ -138,82 +138,52 @@ async function fetchBookmark(linkId: string) { async function inferTagsFromImage( jobId: string, bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, - openai: OpenAI, + inferenceClient: InferenceClient, ) { - const { asset, metadata } = await readAsset({ userId: bookmark.userId, assetId: bookmark.asset.assetId, }); if (!asset) { - throw new Error(`[openai][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`); + throw new Error( + `[inference][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`, + ); } - const base64 = asset.toString('base64'); - - const chatCompletion = await openai.chat.completions.create({ - model: serverConfig.inference.imageModel, - messages: [ - { - role: "user", - content: [ - { type: "text", text: IMAGE_PROMPT_BASE }, - { - type: "image_url", - image_url: { - url: `data:${metadata.contentType};base64,${base64}`, - detail: "low", - }, - }, - ], - }, - ], - max_tokens: 2000, - }); + const base64 = asset.toString("base64"); - const response = chatCompletion.choices[0].message.content; - if (!response) { - throw new Error(`[openai][${jobId}] Got no message content from OpenAI`); - } - return { response, totalTokens: chatCompletion.usage?.total_tokens }; + return await inferenceClient.inferFromImage( + IMAGE_PROMPT_BASE, + metadata.contentType, + base64, + ); } async function inferTagsFromText( - jobId: string, bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, - openai: OpenAI, + inferenceClient: InferenceClient, ) { - const chatCompletion = await openai.chat.completions.create({ - messages: [{ role: "system", content: buildPrompt(bookmark) }], - model: serverConfig.inference.textModel, - response_format: { type: "json_object" }, - }); - - const response = chatCompletion.choices[0].message.content; - if (!response) { - throw new Error(`[openai][${jobId}] Got no message content from OpenAI`); - } - return { response, totalTokens: chatCompletion.usage?.total_tokens }; + return await inferenceClient.inferFromText(buildPrompt(bookmark)); } async function inferTags( jobId: string, bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, - openai: OpenAI, + inferenceClient: InferenceClient, ) { let response; if (bookmark.link || bookmark.text) { - response = await inferTagsFromText(jobId, bookmark, openai); + response = await inferTagsFromText(bookmark, inferenceClient); } else if (bookmark.asset) { - response = await inferTagsFromImage(jobId, bookmark, openai); + response = await inferTagsFromImage(jobId, bookmark, inferenceClient); } else { - throw new Error(`[openai][${jobId}] Unsupported bookmark type`); + throw new Error(`[inference][${jobId}] Unsupported bookmark type`); } try { let tags = openAIResponseSchema.parse(JSON.parse(response.response)).tags; logger.info( - `[openai][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${response.totalTokens} tokens and inferred: ${tags}`, + `[inference][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${response.totalTokens} tokens and inferred: ${tags}`, ); // Sometimes the tags contain the hashtag symbol, let's strip them out if they do. @@ -227,7 +197,7 @@ async function inferTags( return tags; } catch (e) { throw new Error( - `[openai][${jobId}] Failed to parse JSON response from OpenAI: ${e}`, + `[inference][${jobId}] Failed to parse JSON response from inference client: ${e}`, ); } } @@ -292,23 +262,18 @@ async function connectTags( async function runOpenAI(job: Job<ZOpenAIRequest, void>) { const jobId = job.id ?? "unknown"; - const { inference } = serverConfig; - - if (!inference.openAIApiKey) { + const inferenceClient = InferenceClientFactory.build(); + if (!inferenceClient) { logger.debug( - `[openai][${jobId}] OpenAI is not configured, nothing to do now`, + `[inference][${jobId}] No inference client configured, nothing to do now`, ); return; } - const openai = new OpenAI({ - apiKey: inference.openAIApiKey, - }); - const request = zOpenAIRequestSchema.safeParse(job.data); if (!request.success) { throw new Error( - `[openai][${jobId}] Got malformed job request: ${request.error.toString()}`, + `[inference][${jobId}] Got malformed job request: ${request.error.toString()}`, ); } @@ -316,11 +281,11 @@ async function runOpenAI(job: Job<ZOpenAIRequest, void>) { const bookmark = await fetchBookmark(bookmarkId); if (!bookmark) { throw new Error( - `[openai][${jobId}] bookmark with id ${bookmarkId} was not found`, + `[inference][${jobId}] bookmark with id ${bookmarkId} was not found`, ); } - const tags = await inferTags(jobId, bookmark, openai); + const tags = await inferTags(jobId, bookmark, inferenceClient); await connectTags(bookmarkId, tags, bookmark.userId); |
