diff options
| author | MohamedBassem <me@mbassem.com> | 2024-02-07 21:05:57 +0000 |
|---|---|---|
| committer | MohamedBassem <me@mbassem.com> | 2024-02-07 21:05:57 +0000 |
| commit | 8970b3a5375ccfd9b41c8a08722a2fc6bbbe3af9 (patch) | |
| tree | 50e4665944d2fe620522688a10584e29bb0b9e37 /workers/openai.ts | |
| parent | 3ec45e8bbb8285b17c703907d4c161b633663096 (diff) | |
| download | karakeep-8970b3a5375ccfd9b41c8a08722a2fc6bbbe3af9.tar.zst | |
[feature] Add openAI integration for extracting tags from articles
Diffstat (limited to 'workers/openai.ts')
| -rw-r--r-- | workers/openai.ts | 154 |
1 files changed, 154 insertions, 0 deletions
diff --git a/workers/openai.ts b/workers/openai.ts new file mode 100644 index 00000000..cc23f700 --- /dev/null +++ b/workers/openai.ts @@ -0,0 +1,154 @@ +import prisma, { BookmarkedLink, BookmarkedLinkDetails } from "@remember/db"; +import logger from "@remember/shared/logger"; +import { ZOpenAIRequest, zOpenAIRequestSchema } from "@remember/shared/queues"; +import { Job } from "bullmq"; +import OpenAI from "openai"; +import { z } from "zod"; + +const openAIResponseSchema = z.object({ + tags: z.array(z.string()), +}); + +let openai: OpenAI | undefined; + +if (process.env.OPENAI_API_KEY && process.env.OPENAI_ENABLED) { + openai = new OpenAI({ + apiKey: process.env["OPENAI_API_KEY"], // This is the default and can be omitted + }); +} + +function buildPrompt(url: string, description: string) { + return ` +You are a bot who given an article, extracts relevant "hashtags" out of them. +You must respond in JSON with the key "tags" and the value is list of tags. +---- +URL: ${url} +Description: ${description} + `; +} + +async function fetchLink(linkId: string) { + return await prisma.bookmarkedLink.findUnique({ + where: { + id: linkId, + }, + include: { + details: true, + }, + }); +} + +async function inferTags( + jobId: string, + link: BookmarkedLink, + linkDetails: BookmarkedLinkDetails | null, + openai: OpenAI, +) { + const linkDescription = linkDetails?.description; + if (!linkDescription) { + throw new Error( + `[openai][${jobId}] No description found for link "${link.id}". Skipping ...`, + ); + } + + const chatCompletion = await openai.chat.completions.create({ + messages: [ + { role: "system", content: buildPrompt(link.url, linkDescription) }, + ], + model: "gpt-3.5-turbo-0125", + response_format: { type: "json_object" }, + }); + + let response = chatCompletion.choices[0].message.content; + if (!response) { + throw new Error(`[openai][${jobId}] Got no message content from OpenAI`); + } + + try { + const tags = openAIResponseSchema.parse(JSON.parse(response)).tags; + logger.info( + `[openai][${jobId}] Inferring tag for url "${link.url}" used ${chatCompletion.usage?.total_tokens} tokens and inferred: ${tags}`, + ); + return tags; + } catch (e) { + throw new Error( + `[openai][${jobId}] Failed to parse JSON response from OpenAI: ${e}`, + ); + } +} + +async function createTags(tags: string[], userId: string) { + const existingTags = await prisma.bookmarkTags.findMany({ + select: { + id: true, + name: true, + }, + where: { + userId, + name: { + in: tags, + }, + }, + }); + + const existingTagSet = new Set<string>(existingTags.map((t) => t.name)); + + let newTags = tags.filter((t) => !existingTagSet.has(t)); + + // TODO: Prisma doesn't support createMany in Sqlite + let newTagObjects = await Promise.all( + newTags.map((t) => { + return prisma.bookmarkTags.create({ + data: { + name: t, + userId: userId, + }, + }); + }), + ); + + return existingTags.map((t) => t.id).concat(newTagObjects.map((t) => t.id)); +} + +async function connectTags(linkId: string, tagIds: string[]) { + // TODO: Prisma doesn't support createMany in Sqlite + await Promise.all( + tagIds.map((tagId) => { + return prisma.tagsOnLinks.create({ + data: { + tagId, + linkId, + }, + }); + }), + ); +} + +export default async function runOpenAI(job: Job<ZOpenAIRequest, void>) { + const jobId = job.id || "unknown"; + + if (!openai) { + logger.debug( + `[openai][${jobId}] OpenAI is not configured, nothing to do now`, + ); + return; + } + + const request = zOpenAIRequestSchema.safeParse(job.data); + if (!request.success) { + throw new Error( + `[openai][${jobId}] Got malformed job request: ${request.error.toString()}`, + ); + } + + const { linkId } = request.data; + const link = await fetchLink(linkId); + if (!link) { + throw new Error(`[openai][${jobId}] link with id ${linkId} was not found`); + } + + const tags = await inferTags(jobId, link, link.details, openai); + + const tagIds = await createTags(tags, link.userId); + await connectTags(linkId, tagIds); +} |
