diff options
| author | Mohamed Bassem <me@mbassem.com> | 2024-03-19 00:33:11 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-03-19 00:33:11 +0000 |
| commit | 785a5b574992296e187a66412dd42f7b4a686353 (patch) | |
| tree | 64b608927cc63d7494395f639636fd4b36e5a977 /apps/workers/openaiWorker.ts | |
| parent | 549520919c482e72cdf7adae5ba852d1b6cbe5aa (diff) | |
| download | karakeep-785a5b574992296e187a66412dd42f7b4a686353.tar.zst | |
Feature: Add support for uploading images and automatically inferring their tags (#2)
* feature: Experimental support for asset uploads
* feature(web): Add new bookmark type asset
* feature: Add support for automatically tagging images
* fix: Add support for image assets in preview page
* use next Image for fetching the images
* Fix auth and error codes in the route handlers
* Add support for image uploads on mobile
* Fix typing of upload requests
* Remove the ugly dragging box
* Bump mobile version to 1.3
* Change the editor card placeholder to mention uploading images
* Fix a typo
* Change ios icon for photo library
* Silence typescript error
Diffstat (limited to 'apps/workers/openaiWorker.ts')
| -rw-r--r-- | apps/workers/openaiWorker.ts | 104 |
1 file changed, 87 insertions, 17 deletions
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts index 1ec22d32..428f6027 100644 --- a/apps/workers/openaiWorker.ts +++ b/apps/workers/openaiWorker.ts @@ -1,19 +1,21 @@ +import { Job, Worker } from "bullmq"; +import { and, eq, inArray } from "drizzle-orm"; +import OpenAI from "openai"; +import { z } from "zod"; + +import Base64 from "js-base64"; + import { db } from "@hoarder/db"; -import logger from "@hoarder/shared/logger"; +import { assets, bookmarks, bookmarkTags, tagsOnBookmarks } from "@hoarder/db/schema"; import serverConfig from "@hoarder/shared/config"; +import logger from "@hoarder/shared/logger"; import { OpenAIQueue, + queueConnectionDetails, SearchIndexingQueue, ZOpenAIRequest, - queueConnectionDetails, zOpenAIRequestSchema, } from "@hoarder/shared/queues"; -import { Job } from "bullmq"; -import OpenAI from "openai"; -import { z } from "zod"; -import { Worker } from "bullmq"; -import { bookmarkTags, bookmarks, tagsOnBookmarks } from "@hoarder/db/schema"; -import { and, eq, inArray } from "drizzle-orm"; const openAIResponseSchema = z.object({ tags: z.array(z.string()), @@ -67,7 +69,15 @@ export class OpenAiWorker { } } -const PROMPT_BASE = ` +const IMAGE_PROMPT_BASE = ` +I'm building a read-it-later app and I need your help with automatic tagging. +Please analyze the attached image and suggest relevant tags that describe its key themes, topics, and main ideas. +Aim for a variety of tags, including broad categories, specific keywords, and potential sub-genres. If it's a famous website +you may also include a tag for the website. If the tag is not generic enough, don't include it. Aim for 10-15 tags. +If there are no good tags, don't emit any. You must respond in valid JSON with the key "tags" and the value is list of tags. +Don't wrap the response in a markdown code.`; + +const TEXT_PROMPT_BASE = ` I'm building a read-it-later app and I need your help with automatic tagging. 
Please analyze the text after the sentence "CONTENT START HERE:" and suggest relevant tags that describe its key themes, topics, and main ideas. Aim for a variety of tags, including broad categories, specific keywords, and potential sub-genres. If it's a famous website @@ -96,18 +106,18 @@ function buildPrompt( } } return ` -${PROMPT_BASE} +${TEXT_PROMPT_BASE} URL: ${bookmark.link.url} -Title: ${bookmark.link.title || ""} -Description: ${bookmark.link.description || ""} -Content: ${content || ""} +Title: ${bookmark.link.title ?? ""} +Description: ${bookmark.link.description ?? ""} +Content: ${content ?? ""} `; } if (bookmark.text) { // TODO: Ensure that the content doesn't exceed the context length of openai return ` -${PROMPT_BASE} +${TEXT_PROMPT_BASE} ${bookmark.text.text} `; } @@ -121,11 +131,55 @@ async function fetchBookmark(linkId: string) { with: { link: true, text: true, + asset: true, }, }); } -async function inferTags( +async function inferTagsFromImage( + jobId: string, + bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, + openai: OpenAI, +) { + + const asset = await db.query.assets.findFirst({ + where: eq(assets.id, bookmark.asset.assetId), + }); + + if (!asset) { + throw new Error(`[openai][${jobId}] AssetId ${bookmark.asset.assetId} for bookmark ${bookmark.id} not found`); + } + + const base64 = Base64.encode(asset.blob as string); + + const chatCompletion = await openai.chat.completions.create({ + model: "gpt-4-vision-preview", + messages: [ + { + role: "user", + content: [ + { type: "text", text: IMAGE_PROMPT_BASE }, + { + type: "image_url", + image_url: { + url: `data:image/jpeg;base64,${base64}`, + detail: "low", + }, + }, + ], + }, + ], + max_tokens: 2000, + }); + + const response = chatCompletion.choices[0].message.content; + if (!response) { + throw new Error(`[openai][${jobId}] Got no message content from OpenAI`); + } + return {response, totalTokens: chatCompletion.usage?.total_tokens}; +} + +async function inferTagsFromText( 
jobId: string, bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, openai: OpenAI, @@ -140,11 +194,27 @@ async function inferTags( if (!response) { throw new Error(`[openai][${jobId}] Got no message content from OpenAI`); } + return {response, totalTokens: chatCompletion.usage?.total_tokens}; +} + +async function inferTags( + jobId: string, + bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, + openai: OpenAI, +) { + let response; + if (bookmark.link || bookmark.text) { + response = await inferTagsFromText(jobId, bookmark, openai); + } else if (bookmark.asset) { + response = await inferTagsFromImage(jobId, bookmark, openai); + } else { + throw new Error(`[openai][${jobId}] Unsupported bookmark type`); + } try { - let tags = openAIResponseSchema.parse(JSON.parse(response)).tags; + let tags = openAIResponseSchema.parse(JSON.parse(response.response)).tags; logger.info( - `[openai][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${chatCompletion.usage?.total_tokens} tokens and inferred: ${tags}`, + `[openai][${jobId}] Inferring tag for bookmark "${bookmark.id}" used ${response.totalTokens} tokens and inferred: ${tags}`, ); // Sometimes the tags contain the hashtag symbol, let's strip them out if they do. |
