aboutsummaryrefslogtreecommitdiffstats
path: root/workers/openai.ts
diff options
context:
space:
mode:
authorMohamedBassem <me@mbassem.com>2024-02-07 21:05:57 +0000
committerMohamedBassem <me@mbassem.com>2024-02-07 21:05:57 +0000
commit8970b3a5375ccfd9b41c8a08722a2fc6bbbe3af9 (patch)
tree50e4665944d2fe620522688a10584e29bb0b9e37 /workers/openai.ts
parent3ec45e8bbb8285b17c703907d4c161b633663096 (diff)
downloadkarakeep-8970b3a5375ccfd9b41c8a08722a2fc6bbbe3af9.tar.zst
[feature] Add openAI integration for extracting tags from articles
Diffstat (limited to 'workers/openai.ts')
-rw-r--r--workers/openai.ts154
1 files changed, 154 insertions, 0 deletions
diff --git a/workers/openai.ts b/workers/openai.ts
new file mode 100644
index 00000000..cc23f700
--- /dev/null
+++ b/workers/openai.ts
@@ -0,0 +1,154 @@
+import prisma, { BookmarkedLink, BookmarkedLinkDetails } from "@remember/db";
+import logger from "@remember/shared/logger";
+import { ZOpenAIRequest, zOpenAIRequestSchema } from "@remember/shared/queues";
+import { Job } from "bullmq";
+import OpenAI from "openai";
+import { z } from "zod";
+
+const openAIResponseSchema = z.object({
+ tags: z.array(z.string()),
+});
+
+let openai: OpenAI | undefined;
+
+if (process.env.OPENAI_API_KEY && process.env.OPENAI_ENABLED) {
+ openai = new OpenAI({
+ apiKey: process.env["OPENAI_API_KEY"], // This is the default and can be omitted
+ });
+}
+
+function buildPrompt(url: string, description: string) {
+ return `
+You are a bot who given an article, extracts relevant "hashtags" out of them.
+You must respond in JSON with the key "tags" and the value is list of tags.
+----
+URL: ${url}
+Description: ${description}
+ `;
+}
+
+async function fetchLink(linkId: string) {
+ return await prisma.bookmarkedLink.findUnique({
+ where: {
+ id: linkId,
+ },
+ include: {
+ details: true,
+ },
+ });
+}
+
+async function inferTags(
+ jobId: string,
+ link: BookmarkedLink,
+ linkDetails: BookmarkedLinkDetails | null,
+ openai: OpenAI,
+) {
+ const linkDescription = linkDetails?.description;
+ if (!linkDescription) {
+ throw new Error(
+ `[openai][${jobId}] No description found for link "${link.id}". Skipping ...`,
+ );
+ }
+
+ const chatCompletion = await openai.chat.completions.create({
+ messages: [
+ { role: "system", content: buildPrompt(link.url, linkDescription) },
+ ],
+ model: "gpt-3.5-turbo-0125",
+ response_format: { type: "json_object" },
+ });
+
+ let response = chatCompletion.choices[0].message.content;
+ if (!response) {
+ throw new Error(`[openai][${jobId}] Got no message content from OpenAI`);
+ }
+
+ try {
+ const tags = openAIResponseSchema.parse(JSON.parse(response)).tags;
+ logger.info(
+ `[openai][${jobId}] Inferring tag for url "${link.url}" used ${chatCompletion.usage?.total_tokens} tokens and inferred: ${tags}`,
+ );
+ return tags;
+ } catch (e) {
+ throw new Error(
+ `[openai][${jobId}] Failed to parse JSON response from OpenAI: ${e}`,
+ );
+ }
+}
+
+async function createTags(tags: string[], userId: string) {
+ const existingTags = await prisma.bookmarkTags.findMany({
+ select: {
+ id: true,
+ name: true,
+ },
+ where: {
+ userId,
+ name: {
+ in: tags,
+ },
+ },
+ });
+
+ const existingTagSet = new Set<string>(existingTags.map((t) => t.name));
+
+ let newTags = tags.filter((t) => !existingTagSet.has(t));
+
+ // TODO: Prisma doesn't support createMany in Sqlite
+ let newTagObjects = await Promise.all(
+ newTags.map((t) => {
+ return prisma.bookmarkTags.create({
+ data: {
+ name: t,
+ userId: userId,
+ },
+ });
+ }),
+ );
+
+ return existingTags.map((t) => t.id).concat(newTagObjects.map((t) => t.id));
+}
+
+async function connectTags(linkId: string, tagIds: string[]) {
+ // TODO: Prisma doesn't support createMany in Sqlite
+ await Promise.all(
+ tagIds.map((tagId) => {
+ return prisma.tagsOnLinks.create({
+ data: {
+ tagId,
+ linkId,
+ },
+ });
+ }),
+ );
+}
+
+export default async function runOpenAI(job: Job<ZOpenAIRequest, void>) {
+ const jobId = job.id || "unknown";
+
+ if (!openai) {
+ logger.debug(
+ `[openai][${jobId}] OpenAI is not configured, nothing to do now`,
+ );
+ return;
+ }
+
+ const request = zOpenAIRequestSchema.safeParse(job.data);
+ if (!request.success) {
+ throw new Error(
+ `[openai][${jobId}] Got malformed job request: ${request.error.toString()}`,
+ );
+ }
+
+ const { linkId } = request.data;
+ const link = await fetchLink(linkId);
+ if (!link) {
+ throw new Error(`[openai][${jobId}] link with id ${linkId} was not found`);
+ }
+
+ const tags = await inferTags(jobId, link, link.details, openai);
+
+ const tagIds = await createTags(tags, link.userId);
+ await connectTags(linkId, tagIds);
+}