From fdf28ae19ac8d7314bfa6c5d24fdcbabba0aee32 Mon Sep 17 00:00:00 2001 From: kamtschatka Date: Sun, 24 Nov 2024 19:49:17 +0100 Subject: feature: Add support for tag placeholders in custom prompts. #111 (#612) * PR for #111 added a $tags,$aiTags and $userTags placeholder that will be replaced with all tags, ai tags or user tags during inference * Use the new buildImpersonatingTRPCClient util --------- Co-authored-by: Mohamed Bassem --- apps/workers/openaiWorker.ts | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) (limited to 'apps') diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts index 948e92a7..4061c7d2 100644 --- a/apps/workers/openaiWorker.ts +++ b/apps/workers/openaiWorker.ts @@ -1,5 +1,6 @@ import { and, Column, eq, inArray, sql } from "drizzle-orm"; import { DequeuedJob, Runner } from "liteque"; +import { buildImpersonatingTRPCClient } from "trpc"; import { z } from "zod"; import type { InferenceClient } from "@hoarder/shared/inference"; @@ -200,7 +201,47 @@ async function fetchCustomPrompts( }, }); - return prompts.map((p) => p.text); + let promptTexts = prompts.map((p) => p.text); + if (containsTagsPlaceholder(prompts)) { + promptTexts = await replaceTagsPlaceholders(promptTexts, userId); + } + + return promptTexts; +} + +async function replaceTagsPlaceholders( + prompts: string[], + userId: string, +): Promise { + const api = await buildImpersonatingTRPCClient(userId); + const tags = (await api.tags.list()).tags; + const tagsString = `[${tags.map((tag) => tag.name).join(",")}]`; + const aiTagsString = `[${tags + .filter((tag) => tag.numBookmarksByAttachedType.human ?? 0 == 0) + .map((tag) => tag.name) + .join(",")}]`; + const userTagsString = `[${tags + .filter((tag) => tag.numBookmarksByAttachedType.human ?? 0 > 0) + .map((tag) => tag.name) + .join(",")}]`; + + return prompts.map((p) => + p + .replaceAll("$tags", tagsString) + .replaceAll("$aiTags", aiTagsString) + .replaceAll("$userTags", userTagsString), + ); +} + +function containsTagsPlaceholder(prompts: { text: string }[]): boolean { + return ( + prompts.filter( + (p) => + p.text.includes("$tags") || + p.text.includes("$aiTags") || + p.text.includes("$userTags"), + ).length > 0 + ); } async function inferTagsFromPDF( -- cgit v1.2.3-70-g09d2