aboutsummaryrefslogtreecommitdiffstats
path: root/apps/workers/openaiWorker.ts
diff options
context:
space:
mode:
Diffstat (limited to 'apps/workers/openaiWorker.ts')
-rw-r--r--apps/workers/openaiWorker.ts43
1 files changed, 42 insertions, 1 deletions
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
index 948e92a7..4061c7d2 100644
--- a/apps/workers/openaiWorker.ts
+++ b/apps/workers/openaiWorker.ts
@@ -1,5 +1,6 @@
import { and, Column, eq, inArray, sql } from "drizzle-orm";
import { DequeuedJob, Runner } from "liteque";
+import { buildImpersonatingTRPCClient } from "trpc";
import { z } from "zod";
import type { InferenceClient } from "@hoarder/shared/inference";
@@ -200,7 +201,47 @@ async function fetchCustomPrompts(
},
});
- return prompts.map((p) => p.text);
+ let promptTexts = prompts.map((p) => p.text);
+ if (containsTagsPlaceholder(prompts)) {
+ promptTexts = await replaceTagsPlaceholders(promptTexts, userId);
+ }
+
+ return promptTexts;
+}
+
+async function replaceTagsPlaceholders(
+ prompts: string[],
+ userId: string,
+): Promise<string[]> {
+ const api = await buildImpersonatingTRPCClient(userId);
+ const tags = (await api.tags.list()).tags;
+ const tagsString = `[${tags.map((tag) => tag.name).join(",")}]`;
+ const aiTagsString = `[${tags
+ .filter((tag) => tag.numBookmarksByAttachedType.human ?? 0 == 0)
+ .map((tag) => tag.name)
+ .join(",")}]`;
+ const userTagsString = `[${tags
+ .filter((tag) => tag.numBookmarksByAttachedType.human ?? 0 > 0)
+ .map((tag) => tag.name)
+ .join(",")}]`;
+
+ return prompts.map((p) =>
+ p
+ .replaceAll("$tags", tagsString)
+ .replaceAll("$aiTags", aiTagsString)
+ .replaceAll("$userTags", userTagsString),
+ );
+}
+
+function containsTagsPlaceholder(prompts: { text: string }[]): boolean {
+ return (
+ prompts.filter(
+ (p) =>
+ p.text.includes("$tags") ||
+ p.text.includes("$aiTags") ||
+ p.text.includes("$userTags"),
+ ).length > 0
+ );
}
async function inferTagsFromPDF(