diff options
| author | MohamedBassem <me@mbassem.com> | 2024-02-28 21:07:42 +0000 |
|---|---|---|
| committer | MohamedBassem <me@mbassem.com> | 2024-02-28 21:07:42 +0000 |
| commit | 3a1ee94d007dd22454c4ca9035abbb0eccdd2be3 (patch) | |
| tree | f920008d095908e6496743c7aadf072e824305d1 /packages | |
| parent | 3208dda3848ad739f54cebf44c423e2b68e85b2d (diff) | |
| download | karakeep-3a1ee94d007dd22454c4ca9035abbb0eccdd2be3.tar.zst | |
feature: Support tag inference for note bookmarks
Diffstat (limited to 'packages')
| -rw-r--r-- | packages/web/server/api/routers/bookmarks.ts | 21 | ||||
| -rw-r--r-- | packages/workers/openai.ts | 84 |
2 files changed, 65 insertions, 40 deletions
diff --git a/packages/web/server/api/routers/bookmarks.ts b/packages/web/server/api/routers/bookmarks.ts index bfa0580f..4e98eb2f 100644 --- a/packages/web/server/api/routers/bookmarks.ts +++ b/packages/web/server/api/routers/bookmarks.ts @@ -17,7 +17,7 @@ import { bookmarks, tagsOnBookmarks, } from "@hoarder/db/schema"; -import { LinkCrawlerQueue } from "@hoarder/shared/queues"; +import { LinkCrawlerQueue, OpenAIQueue } from "@hoarder/shared/queues"; import { TRPCError, experimental_trpcMiddleware } from "@trpc/server"; import { and, desc, eq, inArray } from "drizzle-orm"; import { ZBookmarkTags } from "@/lib/types/api/tags"; @@ -157,10 +157,21 @@ export const bookmarksAppRouter = router({ ); // Enqueue crawling request - await LinkCrawlerQueue.add("crawl", { - bookmarkId: bookmark.id, - }); - + switch (bookmark.content.type) { + case "link": { + // The crawling job triggers openai when it's done + await LinkCrawlerQueue.add("crawl", { + bookmarkId: bookmark.id, + }); + break; + } + case "text": { + await OpenAIQueue.add("openai", { + bookmarkId: bookmark.id, + }); + break; + } + } return bookmark; }), diff --git a/packages/workers/openai.ts b/packages/workers/openai.ts index ed4c72e8..2d82a204 100644 --- a/packages/workers/openai.ts +++ b/packages/workers/openai.ts @@ -11,13 +11,8 @@ import { Job } from "bullmq"; import OpenAI from "openai"; import { z } from "zod"; import { Worker } from "bullmq"; -import { - bookmarkLinks, - bookmarkTags, - bookmarks, - tagsOnBookmarks, -} from "@hoarder/db/schema"; -import { eq } from "drizzle-orm"; +import { bookmarkTags, bookmarks, tagsOnBookmarks } from "@hoarder/db/schema"; +import { and, eq, inArray } from "drizzle-orm"; const openAIResponseSchema = z.object({ tags: z.array(z.string()), @@ -49,18 +44,41 @@ export class OpenAiWorker { } } -function buildPrompt(url: string, description: string) { - return ` - +const PROMPT_BASE = ` I'm building a read-it-later app and I need your help with automatic tagging. 
Please analyze the following text and suggest relevant tags that describe its key themes, topics, and main ideas. Aim for a variety of tags, including broad categories, specific keywords, and potential sub-genres. If it's a famous website you may also include a tag for the website. Tags should be lowercases and don't contain spaces. If the tag is not generic enough, don't include it. Aim for 3-5 tags. You must respond in JSON with the key "tags" and the value is list of tags. ----- -URL: ${url} -Description: ${description} +`; + +function buildPrompt( + bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, +) { + if (bookmark.link) { + if (!bookmark.link.description) { + throw new Error( + `No description found for link "${bookmark.id}". Skipping ...`, + ); + } + return ` +${PROMPT_BASE} +--- +URL: ${bookmark.link.url} +Description: ${bookmark.link.description} + `; + } + + if (bookmark.text) { + // TODO: Ensure that the content doesn't exceed the context length of openai + return ` +${PROMPT_BASE} +--- +Content: ${bookmark.text.text} `; + } + + throw new Error("Unknown bookmark type"); } async function fetchBookmark(linkId: string) { @@ -68,26 +86,18 @@ async function fetchBookmark(linkId: string) { where: eq(bookmarks.id, linkId), with: { link: true, + text: true, }, }); } async function inferTags( jobId: string, - link: typeof bookmarkLinks.$inferSelect, + bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, openai: OpenAI, ) { - const linkDescription = link?.description; - if (!linkDescription) { - throw new Error( - `[openai][${jobId}] No description found for link "${link.id}". 
Skipping ...`, - ); - } - const chatCompletion = await openai.chat.completions.create({ - messages: [ - { role: "system", content: buildPrompt(link.url, linkDescription) }, - ], + messages: [{ role: "system", content: buildPrompt(bookmark) }], model: "gpt-3.5-turbo-0125", response_format: { type: "json_object" }, }); @@ -100,7 +110,7 @@ async function inferTags( try { let tags = openAIResponseSchema.parse(JSON.parse(response)).tags; logger.info( - `[openai][${jobId}] Inferring tag for url "${link.url}" used ${chatCompletion.usage?.total_tokens} tokens and inferred: ${tags}`, + `[openai][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${chatCompletion.usage?.total_tokens} tokens and inferred: ${tags}`, ); // Sometimes the tags contain the hashtag symbol, let's strip them out if they do. @@ -120,7 +130,7 @@ async function inferTags( } async function createTags(tags: string[], userId: string) { - const res = await db + await db .insert(bookmarkTags) .values( tags.map((t) => ({ @@ -128,8 +138,18 @@ async function createTags(tags: string[], userId: string) { userId, })), ) - .onConflictDoNothing() - .returning({ id: bookmarkTags.id }); + .onConflictDoNothing(); + + const res = await db.query.bookmarkTags.findMany({ + where: and( + eq(bookmarkTags.userId, userId), + inArray(bookmarkTags.name, tags), + ), + columns: { + id: true, + }, + }); + return res.map((r) => r.id); } @@ -174,13 +194,7 @@ async function runOpenAI(job: Job<ZOpenAIRequest, void>) { ); } - if (!bookmark.link) { - throw new Error( - `[openai][${jobId}] bookmark with id ${bookmarkId} doesn't have a link`, - ); - } - - const tags = await inferTags(jobId, bookmark.link, openai); + const tags = await inferTags(jobId, bookmark, openai); const tagIds = await createTags(tags, bookmark.userId); await connectTags(bookmarkId, tagIds); |
