aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMohamedBassem <me@mbassem.com>2024-02-07 21:05:57 +0000
committerMohamedBassem <me@mbassem.com>2024-02-07 21:05:57 +0000
commit8970b3a5375ccfd9b41c8a08722a2fc6bbbe3af9 (patch)
tree50e4665944d2fe620522688a10584e29bb0b9e37
parent3ec45e8bbb8285b17c703907d4c161b633663096 (diff)
downloadkarakeep-8970b3a5375ccfd9b41c8a08722a2fc6bbbe3af9.tar.zst
[feature] Add openAI integration for extracting tags from articles
-rwxr-xr-xbun.lockbbin262344 -> 270592 bytes
-rw-r--r--db/index.ts4
-rw-r--r--db/prisma/migrations/20240207204211_drop_extra_field_in_tags_links/migration.sql21
-rw-r--r--db/prisma/schema.prisma1
-rw-r--r--shared/queues.ts11
-rw-r--r--workers/crawler.ts6
-rw-r--r--workers/index.ts58
-rw-r--r--workers/openai.ts154
-rw-r--r--workers/package.json3
9 files changed, 239 insertions, 19 deletions
diff --git a/bun.lockb b/bun.lockb
index 4ebd8a9b..4673939e 100755
--- a/bun.lockb
+++ b/bun.lockb
Binary files differ
diff --git a/db/index.ts b/db/index.ts
index dbf925f4..fa46ca1f 100644
--- a/db/index.ts
+++ b/db/index.ts
@@ -2,6 +2,8 @@ import { PrismaClient } from "@prisma/client";
const prisma = new PrismaClient();
-export { Prisma } from "@prisma/client";
+// For some weird reason accessing @prisma/client from any package is causing problems (specially in error handling).
+// Re export them here instead.
+export * from "@prisma/client";
export default prisma;
diff --git a/db/prisma/migrations/20240207204211_drop_extra_field_in_tags_links/migration.sql b/db/prisma/migrations/20240207204211_drop_extra_field_in_tags_links/migration.sql
new file mode 100644
index 00000000..78184041
--- /dev/null
+++ b/db/prisma/migrations/20240207204211_drop_extra_field_in_tags_links/migration.sql
@@ -0,0 +1,21 @@
+/*
+ Warnings:
+
+ - You are about to drop the column `bookmarkTagsId` on the `TagsOnLinks` table. All the data in the column will be lost.
+
+*/
+-- RedefineTables
+PRAGMA foreign_keys=OFF;
+CREATE TABLE "new_TagsOnLinks" (
+ "linkId" TEXT NOT NULL,
+ "tagId" TEXT NOT NULL,
+ "attachedAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ CONSTRAINT "TagsOnLinks_linkId_fkey" FOREIGN KEY ("linkId") REFERENCES "BookmarkedLink" ("id") ON DELETE CASCADE ON UPDATE CASCADE,
+ CONSTRAINT "TagsOnLinks_tagId_fkey" FOREIGN KEY ("tagId") REFERENCES "BookmarkTags" ("id") ON DELETE CASCADE ON UPDATE CASCADE
+);
+INSERT INTO "new_TagsOnLinks" ("attachedAt", "linkId", "tagId") SELECT "attachedAt", "linkId", "tagId" FROM "TagsOnLinks";
+DROP TABLE "TagsOnLinks";
+ALTER TABLE "new_TagsOnLinks" RENAME TO "TagsOnLinks";
+CREATE UNIQUE INDEX "TagsOnLinks_linkId_tagId_key" ON "TagsOnLinks"("linkId", "tagId");
+PRAGMA foreign_key_check;
+PRAGMA foreign_keys=ON;
diff --git a/db/prisma/schema.prisma b/db/prisma/schema.prisma
index f5b83b66..0e6d080c 100644
--- a/db/prisma/schema.prisma
+++ b/db/prisma/schema.prisma
@@ -100,7 +100,6 @@ model TagsOnLinks {
tagId String
attachedAt DateTime @default(now())
- bookmarkTagsId String
@@unique([linkId, tagId])
}
diff --git a/shared/queues.ts b/shared/queues.ts
index ac5acc57..a607131d 100644
--- a/shared/queues.ts
+++ b/shared/queues.ts
@@ -6,6 +6,7 @@ export const queueConnectionDetails = {
port: parseInt(process.env.REDIS_PORT || "6379"),
};
+// Link Crawler
export const zCrawlLinkRequestSchema = z.object({
linkId: z.string(),
url: z.string().url(),
@@ -16,3 +17,13 @@ export const LinkCrawlerQueue = new Queue<ZCrawlLinkRequest, void>(
"link_crawler_queue",
{ connection: queueConnectionDetails },
);
+
+// OpenAI Worker
+export const zOpenAIRequestSchema = z.object({
+ linkId: z.string(),
+});
+export type ZOpenAIRequest = z.infer<typeof zOpenAIRequestSchema>;
+
+export const OpenAIQueue = new Queue<ZOpenAIRequest, void>("openai_queue", {
+ connection: queueConnectionDetails,
+});
diff --git a/workers/crawler.ts b/workers/crawler.ts
index c0f433af..817bba45 100644
--- a/workers/crawler.ts
+++ b/workers/crawler.ts
@@ -1,5 +1,6 @@
import logger from "@remember/shared/logger";
import {
+ OpenAIQueue,
ZCrawlLinkRequest,
zCrawlLinkRequestSchema,
} from "@remember/shared/queues";
@@ -69,4 +70,9 @@ export default async function runCrawler(job: Job<ZCrawlLinkRequest, void>) {
details: true,
},
});
+
+ // Enqueue openai job
+ OpenAIQueue.add("openai", {
+ linkId,
+ });
}
diff --git a/workers/index.ts b/workers/index.ts
index 76c6f03f..bf092953 100644
--- a/workers/index.ts
+++ b/workers/index.ts
@@ -2,31 +2,57 @@ import { Worker } from "bullmq";
import {
LinkCrawlerQueue,
+ OpenAIQueue,
ZCrawlLinkRequest,
+ ZOpenAIRequest,
queueConnectionDetails,
} from "@remember/shared/queues";
import logger from "@remember/shared/logger";
import runCrawler from "./crawler";
+import runOpenAI from "./openai";
-logger.info("Starting crawler worker ...");
+function crawlerWorker() {
+ logger.info("Starting crawler worker ...");
+ const worker = new Worker<ZCrawlLinkRequest, void>(
+ LinkCrawlerQueue.name,
+ runCrawler,
+ {
+ connection: queueConnectionDetails,
+ autorun: false,
+ },
+ );
-const crawlerWorker = new Worker<ZCrawlLinkRequest, void>(
- LinkCrawlerQueue.name,
- runCrawler,
- {
+ worker.on("completed", (job) => {
+ const jobId = job?.id || "unknown";
+ logger.info(`[Crawler][${jobId}] Completed successfully`);
+ });
+
+ worker.on("failed", (job, error) => {
+ const jobId = job?.id || "unknown";
+ logger.error(`[Crawler][${jobId}] Crawling job failed: ${error}`);
+ });
+
+ return worker;
+}
+
+function openaiWorker() {
+ logger.info("Starting openai worker ...");
+ const worker = new Worker<ZOpenAIRequest, void>(OpenAIQueue.name, runOpenAI, {
connection: queueConnectionDetails,
autorun: false,
- },
-);
+ });
+
+ worker.on("completed", (job) => {
+ const jobId = job?.id || "unknown";
+ logger.info(`[openai][${jobId}] Completed successfully`);
+ });
-crawlerWorker.on("completed", (job) => {
- const jobId = job?.id || "unknown";
- logger.info(`[Crawler][${jobId}] Completed successfully`);
-});
+ worker.on("failed", (job, error) => {
+ const jobId = job?.id || "unknown";
+ logger.error(`[openai][${jobId}] openai job failed: ${error}`);
+ });
-crawlerWorker.on("failed", (job, error) => {
- const jobId = job?.id || "unknown";
- logger.error(`[Crawler][${jobId}] Crawling job failed: ${error}`);
-});
+ return worker;
+}
-await Promise.all([crawlerWorker.run()]);
+await Promise.all([crawlerWorker().run(), openaiWorker().run()]);
diff --git a/workers/openai.ts b/workers/openai.ts
new file mode 100644
index 00000000..cc23f700
--- /dev/null
+++ b/workers/openai.ts
@@ -0,0 +1,154 @@
+import prisma, { BookmarkedLink, BookmarkedLinkDetails } from "@remember/db";
+import logger from "@remember/shared/logger";
+import { ZOpenAIRequest, zOpenAIRequestSchema } from "@remember/shared/queues";
+import { Job } from "bullmq";
+import OpenAI from "openai";
+import { z } from "zod";
+
+const openAIResponseSchema = z.object({
+ tags: z.array(z.string()),
+});
+
+let openai: OpenAI | undefined;
+
+if (process.env.OPENAI_API_KEY && process.env.OPENAI_ENABLED) {
+ openai = new OpenAI({
+ apiKey: process.env["OPENAI_API_KEY"], // This is the default and can be omitted
+ });
+}
+
+function buildPrompt(url: string, description: string) {
+ return `
+You are a bot who given an article, extracts relevant "hashtags" out of them.
+You must respond in JSON with the key "tags" and the value is list of tags.
+----
+URL: ${url}
+Description: ${description}
+ `;
+}
+
+async function fetchLink(linkId: string) {
+ return await prisma.bookmarkedLink.findUnique({
+ where: {
+ id: linkId,
+ },
+ include: {
+ details: true,
+ },
+ });
+}
+
+async function inferTags(
+ jobId: string,
+ link: BookmarkedLink,
+ linkDetails: BookmarkedLinkDetails | null,
+ openai: OpenAI,
+) {
+ const linkDescription = linkDetails?.description;
+ if (!linkDescription) {
+ throw new Error(
+ `[openai][${jobId}] No description found for link "${link.id}". Skipping ...`,
+ );
+ }
+
+ const chatCompletion = await openai.chat.completions.create({
+ messages: [
+ { role: "system", content: buildPrompt(link.url, linkDescription) },
+ ],
+ model: "gpt-3.5-turbo-0125",
+ response_format: { type: "json_object" },
+ });
+
+ let response = chatCompletion.choices[0].message.content;
+ if (!response) {
+ throw new Error(`[openai][${jobId}] Got no message content from OpenAI`);
+ }
+
+ try {
+ const tags = openAIResponseSchema.parse(JSON.parse(response)).tags;
+ logger.info(
+ `[openai][${jobId}] Inferring tag for url "${link.url}" used ${chatCompletion.usage?.total_tokens} tokens and inferred: ${tags}`,
+ );
+ return tags;
+ } catch (e) {
+ throw new Error(
+ `[openai][${jobId}] Failed to parse JSON response from OpenAI: ${e}`,
+ );
+ }
+}
+
+async function createTags(tags: string[], userId: string) {
+ const existingTags = await prisma.bookmarkTags.findMany({
+ select: {
+ id: true,
+ name: true,
+ },
+ where: {
+ userId,
+ name: {
+ in: tags,
+ },
+ },
+ });
+
+ const existingTagSet = new Set<string>(existingTags.map((t) => t.name));
+
+ let newTags = tags.filter((t) => !existingTagSet.has(t));
+
+ // TODO: Prisma doesn't support createMany in Sqlite
+ let newTagObjects = await Promise.all(
+ newTags.map((t) => {
+ return prisma.bookmarkTags.create({
+ data: {
+ name: t,
+ userId: userId,
+ },
+ });
+ }),
+ );
+
+ return existingTags.map((t) => t.id).concat(newTagObjects.map((t) => t.id));
+}
+
+async function connectTags(linkId: string, tagIds: string[]) {
+ // TODO: Prisma doesn't support createMany in Sqlite
+ await Promise.all(
+ tagIds.map((tagId) => {
+ return prisma.tagsOnLinks.create({
+ data: {
+ tagId,
+ linkId,
+ },
+ });
+ }),
+ );
+}
+
+export default async function runOpenAI(job: Job<ZOpenAIRequest, void>) {
+ const jobId = job.id || "unknown";
+
+ if (!openai) {
+ logger.debug(
+ `[openai][${jobId}] OpenAI is not configured, nothing to do now`,
+ );
+ return;
+ }
+
+ const request = zOpenAIRequestSchema.safeParse(job.data);
+ if (!request.success) {
+ throw new Error(
+ `[openai][${jobId}] Got malformed job request: ${request.error.toString()}`,
+ );
+ }
+
+ const { linkId } = request.data;
+ const link = await fetchLink(linkId);
+ if (!link) {
+ throw new Error(`[openai][${jobId}] link with id ${linkId} was not found`);
+ }
+
+ const tags = await inferTags(jobId, link, link.details, openai);
+
+ const tagIds = await createTags(tags, link.userId);
+ await connectTags(linkId, tagIds);
+}
diff --git a/workers/package.json b/workers/package.json
index 950233ab..1d32f499 100644
--- a/workers/package.json
+++ b/workers/package.json
@@ -11,7 +11,8 @@
"metascraper-logo": "^5.43.4",
"metascraper-title": "^5.43.4",
"metascraper-url": "^5.43.4",
- "metascraper-logo-favicon": "^5.43.4"
+ "metascraper-logo-favicon": "^5.43.4",
+ "openai": "^4.26.1"
},
"devDependencies": {
"@types/metascraper": "^5.14.3"