diff options
Diffstat (limited to 'apps')
| -rw-r--r-- | apps/workers/openaiWorker.ts | 43 |
1 files changed, 42 insertions, 1 deletions
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts index 948e92a7..4061c7d2 100644 --- a/apps/workers/openaiWorker.ts +++ b/apps/workers/openaiWorker.ts @@ -1,5 +1,6 @@ import { and, Column, eq, inArray, sql } from "drizzle-orm"; import { DequeuedJob, Runner } from "liteque"; +import { buildImpersonatingTRPCClient } from "trpc"; import { z } from "zod"; import type { InferenceClient } from "@hoarder/shared/inference"; @@ -200,7 +201,47 @@ async function fetchCustomPrompts( }, }); - return prompts.map((p) => p.text); + let promptTexts = prompts.map((p) => p.text); + if (containsTagsPlaceholder(prompts)) { + promptTexts = await replaceTagsPlaceholders(promptTexts, userId); + } + + return promptTexts; +} + +async function replaceTagsPlaceholders( + prompts: string[], + userId: string, +): Promise<string[]> { + const api = await buildImpersonatingTRPCClient(userId); + const tags = (await api.tags.list()).tags; + const tagsString = `[${tags.map((tag) => tag.name).join(",")}]`; + const aiTagsString = `[${tags + .filter((tag) => tag.numBookmarksByAttachedType.human ?? 0 == 0) + .map((tag) => tag.name) + .join(",")}]`; + const userTagsString = `[${tags + .filter((tag) => tag.numBookmarksByAttachedType.human ?? 0 > 0) + .map((tag) => tag.name) + .join(",")}]`; + + return prompts.map((p) => + p + .replaceAll("$tags", tagsString) + .replaceAll("$aiTags", aiTagsString) + .replaceAll("$userTags", userTagsString), + ); +} + +function containsTagsPlaceholder(prompts: { text: string }[]): boolean { + return ( + prompts.filter( + (p) => + p.text.includes("$tags") || + p.text.includes("$aiTags") || + p.text.includes("$userTags"), + ).length > 0 + ); } async function inferTagsFromPDF( |
