From 2743d9e38ecfdbf757d4d2f97bcf09d601245b59 Mon Sep 17 00:00:00 2001
From: Mohamed Bassem
Date: Sun, 18 May 2025 20:22:59 +0000
Subject: feat: Add AI auto summarization. Fixes #1163

---
 apps/workers/openaiWorker.ts | 463 -------------------------------------------
 1 file changed, 463 deletions(-)
 delete mode 100644 apps/workers/openaiWorker.ts

(limited to 'apps/workers/openaiWorker.ts')

diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
deleted file mode 100644
index c8b2770e..00000000
--- a/apps/workers/openaiWorker.ts
+++ /dev/null
@@ -1,463 +0,0 @@
-import { and, Column, eq, inArray, sql } from "drizzle-orm";
-import { DequeuedJob, Runner } from "liteque";
-import { buildImpersonatingTRPCClient } from "trpc";
-import { z } from "zod";
-
-import type { InferenceClient } from "@karakeep/shared/inference";
-import type { ZOpenAIRequest } from "@karakeep/shared/queues";
-import { db } from "@karakeep/db";
-import {
-  bookmarks,
-  bookmarkTags,
-  customPrompts,
-  tagsOnBookmarks,
-} from "@karakeep/db/schema";
-import { readAsset } from "@karakeep/shared/assetdb";
-import serverConfig from "@karakeep/shared/config";
-import { InferenceClientFactory } from "@karakeep/shared/inference";
-import logger from "@karakeep/shared/logger";
-import { buildImagePrompt, buildTextPrompt } from "@karakeep/shared/prompts";
-import {
-  OpenAIQueue,
-  triggerRuleEngineOnEvent,
-  triggerSearchReindex,
-  triggerWebhook,
-  zOpenAIRequestSchema,
-} from "@karakeep/shared/queues";
-
-const openAIResponseSchema = z.object({
-  tags: z.array(z.string()),
-});
-
-function tagNormalizer(col: Column) {
-  function normalizeTag(tag: string) {
-    return tag.toLowerCase().replace(/[ \-_]/g, "");
-  }
-
-  return {
-    normalizeTag,
-    sql: sql`lower(replace(replace(replace(${col}, ' ', ''), '-', ''), '_', ''))`,
-  };
-}
-
-async function attemptMarkTaggingStatus(
-  jobData: object | undefined,
-  status: "success" | "failure",
-) {
-  if (!jobData) {
-    return;
-  }
-  try {
-    const request = zOpenAIRequestSchema.parse(jobData);
-    await db
-      .update(bookmarks)
-      .set({
-        taggingStatus: status,
-      })
-      .where(eq(bookmarks.id, request.bookmarkId));
-  } catch (e) {
-    logger.error(`Something went wrong when marking the tagging status: ${e}`);
-  }
-}
-
-export class OpenAiWorker {
-  static build() {
-    logger.info("Starting inference worker ...");
-    const worker = new Runner<ZOpenAIRequest>(
-      OpenAIQueue,
-      {
-        run: runOpenAI,
-        onComplete: async (job) => {
-          const jobId = job.id;
-          logger.info(`[inference][${jobId}] Completed successfully`);
-          await attemptMarkTaggingStatus(job.data, "success");
-        },
-        onError: async (job) => {
-          const jobId = job.id;
-          logger.error(
-            `[inference][${jobId}] inference job failed: ${job.error}\n${job.error.stack}`,
-          );
-          if (job.numRetriesLeft == 0) {
-            await attemptMarkTaggingStatus(job?.data, "failure");
-          }
-        },
-      },
-      {
-        concurrency: 1,
-        pollIntervalMs: 1000,
-        timeoutSecs: serverConfig.inference.jobTimeoutSec,
-      },
-    );
-
-    return worker;
-  }
-}
-
-async function buildPrompt(
-  bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
-) {
-  const prompts = await fetchCustomPrompts(bookmark.userId, "text");
-  if (bookmark.link) {
-    if (!bookmark.link.description && !bookmark.link.content) {
-      throw new Error(
-        `No content found for link "${bookmark.id}". Skipping ...`,
-      );
-    }
-
-    const content = bookmark.link.content;
-    return buildTextPrompt(
-      serverConfig.inference.inferredTagLang,
-      prompts,
-      `URL: ${bookmark.link.url}
-Title: ${bookmark.link.title ?? ""}
-Description: ${bookmark.link.description ?? ""}
-Content: ${content ?? ""}`,
-      serverConfig.inference.contextLength,
-    );
-  }
-
-  if (bookmark.text) {
-    return buildTextPrompt(
-      serverConfig.inference.inferredTagLang,
-      prompts,
-      bookmark.text.text ?? "",
-      serverConfig.inference.contextLength,
-    );
-  }
-
-  throw new Error("Unknown bookmark type");
-}
-
-async function fetchBookmark(linkId: string) {
-  return await db.query.bookmarks.findFirst({
-    where: eq(bookmarks.id, linkId),
-    with: {
-      link: true,
-      text: true,
-      asset: true,
-    },
-  });
-}
-
-async function inferTagsFromImage(
-  jobId: string,
-  bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
-  inferenceClient: InferenceClient,
-  abortSignal: AbortSignal,
-) {
-  const { asset, metadata } = await readAsset({
-    userId: bookmark.userId,
-    assetId: bookmark.asset.assetId,
-  });
-
-  if (!asset) {
-    throw new Error(
-      `[inference][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`,
-    );
-  }
-
-  const base64 = asset.toString("base64");
-  return inferenceClient.inferFromImage(
-    buildImagePrompt(
-      serverConfig.inference.inferredTagLang,
-      await fetchCustomPrompts(bookmark.userId, "images"),
-    ),
-    metadata.contentType,
-    base64,
-    { schema: openAIResponseSchema, abortSignal },
-  );
-}
-
-async function fetchCustomPrompts(
-  userId: string,
-  appliesTo: "text" | "images",
-) {
-  const prompts = await db.query.customPrompts.findMany({
-    where: and(
-      eq(customPrompts.userId, userId),
-      inArray(customPrompts.appliesTo, ["all_tagging", appliesTo]),
-    ),
-    columns: {
-      text: true,
-    },
-  });
-
-  let promptTexts = prompts.map((p) => p.text);
-  if (containsTagsPlaceholder(prompts)) {
-    promptTexts = await replaceTagsPlaceholders(promptTexts, userId);
-  }
-
-  return promptTexts;
-}
-
-async function replaceTagsPlaceholders(
-  prompts: string[],
-  userId: string,
-): Promise<string[]> {
-  const api = await buildImpersonatingTRPCClient(userId);
-  const tags = (await api.tags.list()).tags;
-  const tagsString = `[${tags.map((tag) => tag.name).join(", ")}]`;
-  const aiTagsString = `[${tags
-    .filter((tag) => tag.numBookmarksByAttachedType.human ?? 0 == 0)
-    .map((tag) => tag.name)
-    .join(", ")}]`;
-  const userTagsString = `[${tags
-    .filter((tag) => tag.numBookmarksByAttachedType.human ?? 0 > 0)
-    .map((tag) => tag.name)
-    .join(", ")}]`;
-
-  return prompts.map((p) =>
-    p
-      .replaceAll("$tags", tagsString)
-      .replaceAll("$aiTags", aiTagsString)
-      .replaceAll("$userTags", userTagsString),
-  );
-}
-
-function containsTagsPlaceholder(prompts: { text: string }[]): boolean {
-  return (
-    prompts.filter(
-      (p) =>
-        p.text.includes("$tags") ||
-        p.text.includes("$aiTags") ||
-        p.text.includes("$userTags"),
-    ).length > 0
-  );
-}
-
-async function inferTagsFromPDF(
-  _jobId: string,
-  bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
-  inferenceClient: InferenceClient,
-  abortSignal: AbortSignal,
-) {
-  const prompt = buildTextPrompt(
-    serverConfig.inference.inferredTagLang,
-    await fetchCustomPrompts(bookmark.userId, "text"),
-    `Content: ${bookmark.asset.content}`,
-    serverConfig.inference.contextLength,
-  );
-  return inferenceClient.inferFromText(prompt, {
-    schema: openAIResponseSchema,
-    abortSignal,
-  });
-}
-
-async function inferTagsFromText(
-  bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
-  inferenceClient: InferenceClient,
-  abortSignal: AbortSignal,
-) {
-  return await inferenceClient.inferFromText(await buildPrompt(bookmark), {
-    schema: openAIResponseSchema,
-    abortSignal,
-  });
-}
-
-async function inferTags(
-  jobId: string,
-  bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
-  inferenceClient: InferenceClient,
-  abortSignal: AbortSignal,
-) {
-  let response;
-  if (bookmark.link || bookmark.text) {
-    response = await inferTagsFromText(bookmark, inferenceClient, abortSignal);
-  } else if (bookmark.asset) {
-    switch (bookmark.asset.assetType) {
-      case "image":
-        response = await inferTagsFromImage(
-          jobId,
-          bookmark,
-          inferenceClient,
-          abortSignal,
-        );
-        break;
-      case "pdf":
-        response = await inferTagsFromPDF(
-          jobId,
-          bookmark,
-          inferenceClient,
-          abortSignal,
-        );
-        break;
-      default:
-        throw new Error(`[inference][${jobId}] Unsupported bookmark type`);
-    }
-  } else {
-    throw new Error(`[inference][${jobId}] Unsupported bookmark type`);
-  }
-
-  if (!response) {
-    throw new Error(`[inference][${jobId}] Inference response is empty`);
-  }
-
-  try {
-    let tags = openAIResponseSchema.parse(JSON.parse(response.response)).tags;
-    logger.info(
-      `[inference][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${response.totalTokens} tokens and inferred: ${tags}`,
-    );
-
-    // Sometimes the tags contain the hashtag symbol, let's strip them out if they do.
-    // Additionally, trim the tags to prevent whitespaces at the beginning/the end of the tag.
-    tags = tags.map((t) => {
-      let tag = t;
-      if (tag.startsWith("#")) {
-        tag = t.slice(1);
-      }
-      return tag.trim();
-    });
-
-    return tags;
-  } catch (e) {
-    const responseSneak = response.response.substring(0, 20);
-    throw new Error(
-      `[inference][${jobId}] The model ignored our prompt and didn't respond with the expected JSON: ${JSON.stringify(e)}. Here's a sneak peak from the response: ${responseSneak}`,
-    );
-  }
-}
-
-async function connectTags(
-  bookmarkId: string,
-  inferredTags: string[],
-  userId: string,
-) {
-  if (inferredTags.length == 0) {
-    return;
-  }
-
-  await db.transaction(async (tx) => {
-    // Attempt to match exiting tags with the new ones
-    const { matchedTagIds, notFoundTagNames } = await (async () => {
-      const { normalizeTag, sql: normalizedTagSql } = tagNormalizer(
-        bookmarkTags.name,
-      );
-      const normalizedInferredTags = inferredTags.map((t) => ({
-        originalTag: t,
-        normalizedTag: normalizeTag(t),
-      }));
-
-      const matchedTags = await tx.query.bookmarkTags.findMany({
-        where: and(
-          eq(bookmarkTags.userId, userId),
-          inArray(
-            normalizedTagSql,
-            normalizedInferredTags.map((t) => t.normalizedTag),
-          ),
-        ),
-      });
-
-      const matchedTagIds = matchedTags.map((r) => r.id);
-      const notFoundTagNames = normalizedInferredTags
-        .filter(
-          (t) =>
-            !matchedTags.some(
-              (mt) => normalizeTag(mt.name) === t.normalizedTag,
-            ),
-        )
-        .map((t) => t.originalTag);
-
-      return { matchedTagIds, notFoundTagNames };
-    })();
-
-    // Create tags that didn't exist previously
-    let newTagIds: string[] = [];
-    if (notFoundTagNames.length > 0) {
-      newTagIds = (
-        await tx
-          .insert(bookmarkTags)
-          .values(
-            notFoundTagNames.map((t) => ({
-              name: t,
-              userId,
-            })),
-          )
-          .onConflictDoNothing()
-          .returning()
-      ).map((t) => t.id);
-    }
-
-    // Delete old AI tags
-    const detachedTags = await tx
-      .delete(tagsOnBookmarks)
-      .where(
-        and(
-          eq(tagsOnBookmarks.attachedBy, "ai"),
-          eq(tagsOnBookmarks.bookmarkId, bookmarkId),
-        ),
-      )
-      .returning();
-
-    const allTagIds = new Set([...matchedTagIds, ...newTagIds]);
-
-    // Attach new ones
-    const attachedTags = await tx
-      .insert(tagsOnBookmarks)
-      .values(
-        [...allTagIds].map((tagId) => ({
-          tagId,
-          bookmarkId,
-          attachedBy: "ai" as const,
-        })),
-      )
-      .onConflictDoNothing()
-      .returning();
-
-    await triggerRuleEngineOnEvent(bookmarkId, [
-      ...detachedTags.map((t) => ({
-        type: "tagRemoved" as const,
-        tagId: t.tagId,
-      })),
-      ...attachedTags.map((t) => ({
-        type: "tagAdded" as const,
-        tagId: t.tagId,
-      })),
-    ]);
-  });
-}
-
-async function runOpenAI(job: DequeuedJob<ZOpenAIRequest>) {
-  const jobId = job.id;
-
-  const inferenceClient = InferenceClientFactory.build();
-  if (!inferenceClient) {
-    logger.debug(
-      `[inference][${jobId}] No inference client configured, nothing to do now`,
-    );
-    return;
-  }
-
-  const request = zOpenAIRequestSchema.safeParse(job.data);
-  if (!request.success) {
-    throw new Error(
-      `[inference][${jobId}] Got malformed job request: ${request.error.toString()}`,
-    );
-  }
-
-  const { bookmarkId } = request.data;
-  const bookmark = await fetchBookmark(bookmarkId);
-  if (!bookmark) {
-    throw new Error(
-      `[inference][${jobId}] bookmark with id ${bookmarkId} was not found`,
-    );
-  }
-
-  logger.info(
-    `[inference][${jobId}] Starting an inference job for bookmark with id "${bookmark.id}"`,
-  );
-
-  const tags = await inferTags(
-    jobId,
-    bookmark,
-    inferenceClient,
-    job.abortSignal,
-  );
-
-  await connectTags(bookmarkId, tags, bookmark.userId);
-
-  // Trigger a webhook
-  await triggerWebhook(bookmarkId, "ai tagged");
-
-  // Update the search index
-  await triggerSearchReindex(bookmarkId);
-}
--
cgit v1.2.3-70-g09d2