aboutsummaryrefslogtreecommitdiffstats
path: root/apps/workers/openaiWorker.ts
diff options
context:
space:
mode:
Diffstat (limited to 'apps/workers/openaiWorker.ts')
-rw-r--r--apps/workers/openaiWorker.ts97
1 files changed, 31 insertions, 66 deletions
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
index 5f785f2f..b706fb90 100644
--- a/apps/workers/openaiWorker.ts
+++ b/apps/workers/openaiWorker.ts
@@ -1,13 +1,11 @@
import { Job, Worker } from "bullmq";
import { and, eq, inArray } from "drizzle-orm";
-import OpenAI from "openai";
import { z } from "zod";
import { db } from "@hoarder/db";
import { bookmarks, bookmarkTags, tagsOnBookmarks } from "@hoarder/db/schema";
-import serverConfig from "@hoarder/shared/config";
-import logger from "@hoarder/shared/logger";
import { readAsset } from "@hoarder/shared/assetdb";
+import logger from "@hoarder/shared/logger";
import {
OpenAIQueue,
queueConnectionDetails,
@@ -16,6 +14,8 @@ import {
zOpenAIRequestSchema,
} from "@hoarder/shared/queues";
+import { InferenceClientFactory, InferenceClient } from "./inference";
+
const openAIResponseSchema = z.object({
tags: z.array(z.string()),
});
@@ -41,8 +41,8 @@ async function attemptMarkTaggingStatus(
}
export class OpenAiWorker {
- static build() {
- logger.info("Starting openai worker ...");
+ static async build() {
+ logger.info("Starting inference worker ...");
const worker = new Worker<ZOpenAIRequest, void>(
OpenAIQueue.name,
runOpenAI,
@@ -54,13 +54,13 @@ export class OpenAiWorker {
worker.on("completed", async (job): Promise<void> => {
const jobId = job?.id ?? "unknown";
- logger.info(`[openai][${jobId}] Completed successfully`);
+ logger.info(`[inference][${jobId}] Completed successfully`);
await attemptMarkTaggingStatus(job?.data, "success");
});
worker.on("failed", async (job, error): Promise<void> => {
const jobId = job?.id ?? "unknown";
- logger.error(`[openai][${jobId}] openai job failed: ${error}`);
+ logger.error(`[inference][${jobId}] inference job failed: ${error}`);
await attemptMarkTaggingStatus(job?.data, "failure");
});
@@ -138,82 +138,52 @@ async function fetchBookmark(linkId: string) {
async function inferTagsFromImage(
jobId: string,
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
- openai: OpenAI,
+ inferenceClient: InferenceClient,
) {
-
const { asset, metadata } = await readAsset({
userId: bookmark.userId,
assetId: bookmark.asset.assetId,
});
if (!asset) {
- throw new Error(`[openai][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`);
+ throw new Error(
+ `[inference][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`,
+ );
}
- const base64 = asset.toString('base64');
-
- const chatCompletion = await openai.chat.completions.create({
- model: serverConfig.inference.imageModel,
- messages: [
- {
- role: "user",
- content: [
- { type: "text", text: IMAGE_PROMPT_BASE },
- {
- type: "image_url",
- image_url: {
- url: `data:${metadata.contentType};base64,${base64}`,
- detail: "low",
- },
- },
- ],
- },
- ],
- max_tokens: 2000,
- });
+ const base64 = asset.toString("base64");
- const response = chatCompletion.choices[0].message.content;
- if (!response) {
- throw new Error(`[openai][${jobId}] Got no message content from OpenAI`);
- }
- return { response, totalTokens: chatCompletion.usage?.total_tokens };
+ return await inferenceClient.inferFromImage(
+ IMAGE_PROMPT_BASE,
+ metadata.contentType,
+ base64,
+ );
}
async function inferTagsFromText(
- jobId: string,
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
- openai: OpenAI,
+ inferenceClient: InferenceClient,
) {
- const chatCompletion = await openai.chat.completions.create({
- messages: [{ role: "system", content: buildPrompt(bookmark) }],
- model: serverConfig.inference.textModel,
- response_format: { type: "json_object" },
- });
-
- const response = chatCompletion.choices[0].message.content;
- if (!response) {
- throw new Error(`[openai][${jobId}] Got no message content from OpenAI`);
- }
- return { response, totalTokens: chatCompletion.usage?.total_tokens };
+ return await inferenceClient.inferFromText(buildPrompt(bookmark));
}
async function inferTags(
jobId: string,
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
- openai: OpenAI,
+ inferenceClient: InferenceClient,
) {
let response;
if (bookmark.link || bookmark.text) {
- response = await inferTagsFromText(jobId, bookmark, openai);
+ response = await inferTagsFromText(bookmark, inferenceClient);
} else if (bookmark.asset) {
- response = await inferTagsFromImage(jobId, bookmark, openai);
+ response = await inferTagsFromImage(jobId, bookmark, inferenceClient);
} else {
- throw new Error(`[openai][${jobId}] Unsupported bookmark type`);
+ throw new Error(`[inference][${jobId}] Unsupported bookmark type`);
}
try {
let tags = openAIResponseSchema.parse(JSON.parse(response.response)).tags;
logger.info(
- `[openai][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${response.totalTokens} tokens and inferred: ${tags}`,
+ `[inference][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${response.totalTokens} tokens and inferred: ${tags}`,
);
// Sometimes the tags contain the hashtag symbol, let's strip them out if they do.
@@ -227,7 +197,7 @@ async function inferTags(
return tags;
} catch (e) {
throw new Error(
- `[openai][${jobId}] Failed to parse JSON response from OpenAI: ${e}`,
+ `[inference][${jobId}] Failed to parse JSON response from inference client: ${e}`,
);
}
}
@@ -292,23 +262,18 @@ async function connectTags(
async function runOpenAI(job: Job<ZOpenAIRequest, void>) {
const jobId = job.id ?? "unknown";
- const { inference } = serverConfig;
-
- if (!inference.openAIApiKey) {
+ const inferenceClient = InferenceClientFactory.build();
+ if (!inferenceClient) {
logger.debug(
- `[openai][${jobId}] OpenAI is not configured, nothing to do now`,
+ `[inference][${jobId}] No inference client configured, nothing to do now`,
);
return;
}
- const openai = new OpenAI({
- apiKey: inference.openAIApiKey,
- });
-
const request = zOpenAIRequestSchema.safeParse(job.data);
if (!request.success) {
throw new Error(
- `[openai][${jobId}] Got malformed job request: ${request.error.toString()}`,
+ `[inference][${jobId}] Got malformed job request: ${request.error.toString()}`,
);
}
@@ -316,11 +281,11 @@ async function runOpenAI(job: Job<ZOpenAIRequest, void>) {
const bookmark = await fetchBookmark(bookmarkId);
if (!bookmark) {
throw new Error(
- `[openai][${jobId}] bookmark with id ${bookmarkId} was not found`,
+ `[inference][${jobId}] bookmark with id ${bookmarkId} was not found`,
);
}
- const tags = await inferTags(jobId, bookmark, openai);
+ const tags = await inferTags(jobId, bookmark, inferenceClient);
await connectTags(bookmarkId, tags, bookmark.userId);