aboutsummaryrefslogtreecommitdiffstats
path: root/packages
diff options
context:
space:
mode:
authorMohamedBassem <me@mbassem.com>2024-02-28 21:07:42 +0000
committerMohamedBassem <me@mbassem.com>2024-02-28 21:07:42 +0000
commit3a1ee94d007dd22454c4ca9035abbb0eccdd2be3 (patch)
treef920008d095908e6496743c7aadf072e824305d1 /packages
parent3208dda3848ad739f54cebf44c423e2b68e85b2d (diff)
downloadkarakeep-3a1ee94d007dd22454c4ca9035abbb0eccdd2be3.tar.zst
feature: Support tag inferance for note bookmarks
Diffstat (limited to 'packages')
-rw-r--r--packages/web/server/api/routers/bookmarks.ts21
-rw-r--r--packages/workers/openai.ts84
2 files changed, 65 insertions, 40 deletions
diff --git a/packages/web/server/api/routers/bookmarks.ts b/packages/web/server/api/routers/bookmarks.ts
index bfa0580f..4e98eb2f 100644
--- a/packages/web/server/api/routers/bookmarks.ts
+++ b/packages/web/server/api/routers/bookmarks.ts
@@ -17,7 +17,7 @@ import {
bookmarks,
tagsOnBookmarks,
} from "@hoarder/db/schema";
-import { LinkCrawlerQueue } from "@hoarder/shared/queues";
+import { LinkCrawlerQueue, OpenAIQueue } from "@hoarder/shared/queues";
import { TRPCError, experimental_trpcMiddleware } from "@trpc/server";
import { and, desc, eq, inArray } from "drizzle-orm";
import { ZBookmarkTags } from "@/lib/types/api/tags";
@@ -157,10 +157,21 @@ export const bookmarksAppRouter = router({
);
// Enqueue crawling request
- await LinkCrawlerQueue.add("crawl", {
- bookmarkId: bookmark.id,
- });
-
+ switch (bookmark.content.type) {
+ case "link": {
+ // The crawling job triggers openai when it's done
+ await LinkCrawlerQueue.add("crawl", {
+ bookmarkId: bookmark.id,
+ });
+ break;
+ }
+ case "text": {
+ await OpenAIQueue.add("openai", {
+ bookmarkId: bookmark.id,
+ });
+ break;
+ }
+ }
return bookmark;
}),
diff --git a/packages/workers/openai.ts b/packages/workers/openai.ts
index ed4c72e8..2d82a204 100644
--- a/packages/workers/openai.ts
+++ b/packages/workers/openai.ts
@@ -11,13 +11,8 @@ import { Job } from "bullmq";
import OpenAI from "openai";
import { z } from "zod";
import { Worker } from "bullmq";
-import {
- bookmarkLinks,
- bookmarkTags,
- bookmarks,
- tagsOnBookmarks,
-} from "@hoarder/db/schema";
-import { eq } from "drizzle-orm";
+import { bookmarkTags, bookmarks, tagsOnBookmarks } from "@hoarder/db/schema";
+import { and, eq, inArray } from "drizzle-orm";
const openAIResponseSchema = z.object({
tags: z.array(z.string()),
@@ -49,18 +44,41 @@ export class OpenAiWorker {
}
}
-function buildPrompt(url: string, description: string) {
- return `
-
+const PROMPT_BASE = `
I'm building a read-it-later app and I need your help with automatic tagging.
Please analyze the following text and suggest relevant tags that describe its key themes, topics, and main ideas.
Aim for a variety of tags, including broad categories, specific keywords, and potential sub-genres. If it's a famous website
you may also include a tag for the website. Tags should be lowercases and don't contain spaces. If the tag is not generic enough, don't
include it. Aim for 3-5 tags. You must respond in JSON with the key "tags" and the value is list of tags.
-----
-URL: ${url}
-Description: ${description}
+`;
+
+function buildPrompt(
+ bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
+) {
+ if (bookmark.link) {
+ if (!bookmark.link.description) {
+ throw new Error(
+ `No description found for link "${bookmark.id}". Skipping ...`,
+ );
+ }
+ return `
+${PROMPT_BASE}
+---
+URL: ${bookmark.link.url}
+Description: ${bookmark.link.description}
+ `;
+ }
+
+ if (bookmark.text) {
+ // TODO: Ensure that the content doesn't exceed the context length of openai
+ return `
+${PROMPT_BASE}
+---
+Content: ${bookmark.text.text}
`;
+ }
+
+ throw new Error("Unknown bookmark type");
}
async function fetchBookmark(linkId: string) {
@@ -68,26 +86,18 @@ async function fetchBookmark(linkId: string) {
where: eq(bookmarks.id, linkId),
with: {
link: true,
+ text: true,
},
});
}
async function inferTags(
jobId: string,
- link: typeof bookmarkLinks.$inferSelect,
+ bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
openai: OpenAI,
) {
- const linkDescription = link?.description;
- if (!linkDescription) {
- throw new Error(
- `[openai][${jobId}] No description found for link "${link.id}". Skipping ...`,
- );
- }
-
const chatCompletion = await openai.chat.completions.create({
- messages: [
- { role: "system", content: buildPrompt(link.url, linkDescription) },
- ],
+ messages: [{ role: "system", content: buildPrompt(bookmark) }],
model: "gpt-3.5-turbo-0125",
response_format: { type: "json_object" },
});
@@ -100,7 +110,7 @@ async function inferTags(
try {
let tags = openAIResponseSchema.parse(JSON.parse(response)).tags;
logger.info(
- `[openai][${jobId}] Inferring tag for url "${link.url}" used ${chatCompletion.usage?.total_tokens} tokens and inferred: ${tags}`,
+ `[openai][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${chatCompletion.usage?.total_tokens} tokens and inferred: ${tags}`,
);
// Sometimes the tags contain the hashtag symbol, let's strip them out if they do.
@@ -120,7 +130,7 @@ async function inferTags(
}
async function createTags(tags: string[], userId: string) {
- const res = await db
+ await db
.insert(bookmarkTags)
.values(
tags.map((t) => ({
@@ -128,8 +138,18 @@ async function createTags(tags: string[], userId: string) {
userId,
})),
)
- .onConflictDoNothing()
- .returning({ id: bookmarkTags.id });
+ .onConflictDoNothing();
+
+ const res = await db.query.bookmarkTags.findMany({
+ where: and(
+ eq(bookmarkTags.userId, userId),
+ inArray(bookmarkTags.name, tags),
+ ),
+ columns: {
+ id: true,
+ },
+ });
+
return res.map((r) => r.id);
}
@@ -174,13 +194,7 @@ async function runOpenAI(job: Job<ZOpenAIRequest, void>) {
);
}
- if (!bookmark.link) {
- throw new Error(
- `[openai][${jobId}] bookmark with id ${bookmarkId} doesn't have a link`,
- );
- }
-
- const tags = await inferTags(jobId, bookmark.link, openai);
+ const tags = await inferTags(jobId, bookmark, openai);
const tagIds = await createTags(tags, bookmark.userId);
await connectTags(bookmarkId, tagIds);