diff options
| author | Mohamed Bassem <me@mbassem.com> | 2026-02-09 00:09:10 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-02-09 00:09:10 +0000 |
| commit | 4186c4c64c68892248ce8671d9b8e67fc7f884a0 (patch) | |
| tree | 91bbbfc0bb47a966b9e340fdbe2a61b2e10ebd19 /apps | |
| parent | 77b186c3a599297da0cf19e923c66607ad7d74e7 (diff) | |
| download | karakeep-4186c4c64c68892248ce8671d9b8e67fc7f884a0.tar.zst | |
feat(ai): Support restricting AI tags to a subset of existing tags (#2444)
* feat(ai): Support restricting AI tags to a subset of existing tags
Co-authored-by: Claude <noreply@anthropic.com>
Diffstat (limited to 'apps')
| -rw-r--r-- | apps/web/components/dashboard/bookmarks/TagsEditor.tsx | 57 | ||||
| -rw-r--r-- | apps/web/components/settings/AISettings.tsx | 160 | ||||
| -rw-r--r-- | apps/web/lib/i18n/locales/en/translation.json | 8 | ||||
| -rw-r--r-- | apps/web/lib/userSettings.tsx | 1 | ||||
| -rw-r--r-- | apps/workers/workers/inference/tagging.ts | 34 |
5 files changed, 251 insertions, 9 deletions
diff --git a/apps/web/components/dashboard/bookmarks/TagsEditor.tsx b/apps/web/components/dashboard/bookmarks/TagsEditor.tsx index 45fae173..ec4a9d8a 100644 --- a/apps/web/components/dashboard/bookmarks/TagsEditor.tsx +++ b/apps/web/components/dashboard/bookmarks/TagsEditor.tsx @@ -13,6 +13,7 @@ import { PopoverTrigger, } from "@/components/ui/popover"; import { useClientConfig } from "@/lib/clientConfig"; +import { useTranslation } from "@/lib/i18n/client"; import { cn } from "@/lib/utils"; import { keepPreviousData, useQuery } from "@tanstack/react-query"; import { Command as CommandPrimitive } from "cmdk"; @@ -26,13 +27,18 @@ export function TagsEditor({ onAttach, onDetach, disabled, + allowCreation = true, + placeholder, }: { tags: ZBookmarkTags[]; onAttach: (tag: { tagName: string; tagId?: string }) => void; onDetach: (tag: { tagName: string; tagId: string }) => void; disabled?: boolean; + allowCreation?: boolean; + placeholder?: string; }) { const api = useTRPC(); + const { t } = useTranslation(); const demoMode = !!useClientConfig().demoMode; const isDisabled = demoMode || disabled; const inputRef = React.useRef<HTMLInputElement>(null); @@ -41,6 +47,7 @@ export function TagsEditor({ const [inputValue, setInputValue] = React.useState(""); const [optimisticTags, setOptimisticTags] = useState<ZBookmarkTags[]>(_tags); const tempIdCounter = React.useRef(0); + const hasInitializedRef = React.useRef(_tags.length > 0); const generateTempId = React.useCallback(() => { tempIdCounter.current += 1; @@ -55,22 +62,39 @@ export function TagsEditor({ }, []); React.useEffect(() => { + // When allowCreation is false, only sync on initial load + // After that, rely on optimistic updates to avoid re-ordering + if (!allowCreation) { + if (!hasInitializedRef.current && _tags.length > 0) { + hasInitializedRef.current = true; + setOptimisticTags(_tags); + } + return; + } + + // For allowCreation mode, sync server state with optimistic state setOptimisticTags((prev) => { - let results = prev; + // Start with a copy to avoid mutating the previous state + const results = [...prev]; + let changed = false; + for (const tag of _tags) { const idx = results.findIndex((t) => t.name === tag.name); if (idx == -1) { results.push(tag); + changed = true; continue; } if (results[idx].id.startsWith("temp-")) { results[idx] = tag; + changed = true; continue; } } - return results; + + return changed ? results : prev; }); - }, [_tags]); + }, [_tags, allowCreation]); const { data: filteredOptions, isLoading: isExistingTagsLoading } = useQuery( api.tags.list.queryOptions( @@ -124,7 +148,7 @@ export function TagsEditor({ (opt) => opt.name.toLowerCase() === trimmedInputValue.toLowerCase(), ); - if (!exactMatch) { + if (!exactMatch && allowCreation) { return [ { id: "create-new", @@ -138,7 +162,7 @@ export function TagsEditor({ } return baseOptions; - }, [filteredOptions, trimmedInputValue]); + }, [filteredOptions, trimmedInputValue, allowCreation]); const onChange = ( actionMeta: @@ -258,6 +282,24 @@ export function TagsEditor({ } }; + const inputPlaceholder = + placeholder ?? + (allowCreation + ? t("tags.search_or_create_placeholder", { + defaultValue: "Search or create tags...", + }) + : t("tags.search_placeholder", { + defaultValue: "Search tags...", + })); + const visiblePlaceholder = + optimisticTags.length === 0 ? inputPlaceholder : undefined; + const inputWidth = Math.max( + inputValue.length > 0 + ? inputValue.length + : Math.min(visiblePlaceholder?.length ?? 1, 24), + 1, + ); + return ( <div ref={containerRef} className="w-full"> <Popover open={open && !isDisabled} onOpenChange={handleOpenChange}> @@ -313,8 +355,9 @@ export function TagsEditor({ value={inputValue} onKeyDown={handleKeyDown} onValueChange={(v) => setInputValue(v)} + placeholder={visiblePlaceholder} className="bg-transparent outline-none placeholder:text-muted-foreground" - style={{ width: `${Math.max(inputValue.length, 1)}ch` }} + style={{ width: `${inputWidth}ch` }} disabled={isDisabled} /> {isExistingTagsLoading && ( @@ -331,7 +374,7 @@ export function TagsEditor({ <CommandList className="max-h-64"> {displayedOptions.length === 0 ? ( <CommandEmpty> - {trimmedInputValue ? ( + {trimmedInputValue && allowCreation ? ( <div className="flex items-center justify-between px-2 py-1.5"> <span>Create "{trimmedInputValue}"</span> <Button diff --git a/apps/web/components/settings/AISettings.tsx b/apps/web/components/settings/AISettings.tsx index 58710fe8..6d28f4f8 100644 --- a/apps/web/components/settings/AISettings.tsx +++ b/apps/web/components/settings/AISettings.tsx @@ -1,5 +1,7 @@ "use client"; +import React from "react"; +import { TagsEditor } from "@/components/dashboard/bookmarks/TagsEditor"; import { ActionButton } from "@/components/ui/action-button"; import { Badge } from "@/components/ui/badge"; import { @@ -48,6 +50,8 @@ import { Info, Plus, Save, Trash2 } from "lucide-react"; import { Controller, useForm } from "react-hook-form"; import { z } from "zod"; +import type { ZBookmarkTags } from "@karakeep/shared/types/tags"; +import { useDebounce } from "@karakeep/shared-react/hooks/use-debounce"; import { useUpdateUserSettings } from "@karakeep/shared-react/hooks/users"; import { useTRPC } from "@karakeep/shared-react/trpc"; import { @@ -340,6 +344,142 @@ export function TagStyleSelector() { ); } +export function CuratedTagsSelector() { + const api = useTRPC(); + const { t } = useTranslation(); + const settings = useUserSettings(); + + const { mutate: updateSettings, isPending: isUpdatingCuratedTags } = + useUpdateUserSettings({ + onSuccess: () => { + toast({ + description: t("settings.ai.curated_tags_updated"), + }); + }, + onError: () => { + toast({ + description: t("settings.ai.curated_tags_update_failed"), + variant: "destructive", + }); + }, + }); + + const areTagIdsEqual = React.useCallback((a: string[], b: string[]) => { + return a.length === b.length && a.every((id, index) => id === b[index]); + }, []); + + const curatedTagIds = React.useMemo( + () => settings?.curatedTagIds ?? [], + [settings?.curatedTagIds], + ); + const [localCuratedTagIds, setLocalCuratedTagIds] = + React.useState<string[]>(curatedTagIds); + const debouncedCuratedTagIds = useDebounce(localCuratedTagIds, 300); + const lastServerCuratedTagIdsRef = React.useRef(curatedTagIds); + const lastSubmittedCuratedTagIdsRef = React.useRef<string[] | null>(null); + + React.useEffect(() => { + const hadUnsyncedLocalChanges = !areTagIdsEqual( + localCuratedTagIds, + lastServerCuratedTagIdsRef.current, + ); + + if ( + !hadUnsyncedLocalChanges && + !areTagIdsEqual(localCuratedTagIds, curatedTagIds) + ) { + setLocalCuratedTagIds(curatedTagIds); + } + + lastServerCuratedTagIdsRef.current = curatedTagIds; + }, [areTagIdsEqual, curatedTagIds, localCuratedTagIds]); + + React.useEffect(() => { + if (isUpdatingCuratedTags) { + return; + } + + if (areTagIdsEqual(debouncedCuratedTagIds, curatedTagIds)) { + lastSubmittedCuratedTagIdsRef.current = null; + return; + } + + if ( + lastSubmittedCuratedTagIdsRef.current && + areTagIdsEqual( + lastSubmittedCuratedTagIdsRef.current, + debouncedCuratedTagIds, + ) + ) { + return; + } + + lastSubmittedCuratedTagIdsRef.current = debouncedCuratedTagIds; + updateSettings({ + curatedTagIds: + debouncedCuratedTagIds.length > 0 ? debouncedCuratedTagIds : null, + }); + }, [ + areTagIdsEqual, + curatedTagIds, + debouncedCuratedTagIds, + isUpdatingCuratedTags, + updateSettings, + ]); + + // Fetch selected tags to display their names + const { data: selectedTagsData } = useQuery( + api.tags.list.queryOptions( + { ids: localCuratedTagIds }, + { enabled: localCuratedTagIds.length > 0 }, + ), + ); + + const selectedTags: ZBookmarkTags[] = React.useMemo(() => { + const tagsMap = new Map( + (selectedTagsData?.tags ?? []).map((tag) => [tag.id, tag]), + ); + // Preserve the order from curatedTagIds instead of server sort order + return localCuratedTagIds + .map((id) => tagsMap.get(id)) + .filter((tag): tag is NonNullable<typeof tag> => tag != null) + .map((tag) => ({ + id: tag.id, + name: tag.name, + attachedBy: "human" as const, + })); + }, [selectedTagsData?.tags, localCuratedTagIds]); + + return ( + <SettingsSection + title={t("settings.ai.curated_tags")} + description={t("settings.ai.curated_tags_description")} + > + <TagsEditor + tags={selectedTags} + placeholder="Select curated tags..." + onAttach={(tag) => { + const tagId = tag.tagId; + if (tagId) { + setLocalCuratedTagIds((prev) => { + if (prev.includes(tagId)) { + return prev; + } + return [...prev, tagId]; + }); + } + }} + onDetach={(tag) => { + setLocalCuratedTagIds((prev) => { + return prev.filter((id) => id !== tag.tagId); + }); + }} + allowCreation={false} + /> + </SettingsSection> + ); +} + export function PromptEditor() { const api = useTRPC(); const { t } = useTranslation(); @@ -617,9 +757,24 @@ export function PromptDemo() { const clientConfig = useClientConfig(); const tagStyle = settings?.tagStyle ?? "as-generated"; + const curatedTagIds = settings?.curatedTagIds ?? []; + const { data: tagsData } = useQuery( + api.tags.list.queryOptions( + { ids: curatedTagIds }, + { enabled: curatedTagIds.length > 0 }, + ), + ); const inferredTagLang = settings?.inferredTagLang ?? clientConfig.inference.inferredTagLang; + // Resolve curated tag names for preview + const curatedTagNames = + curatedTagIds.length > 0 && tagsData?.tags + ? curatedTagIds + .map((id) => tagsData.tags.find((tag) => tag.id === id)?.name) + .filter((name): name is string => Boolean(name)) + : undefined; + return ( <SettingsSection title={t("settings.ai.prompt_preview")} @@ -640,6 +795,7 @@ export function PromptDemo() { .map((p) => p.text), "\n<CONTENT_HERE>\n", tagStyle, + curatedTagNames, ).trim()} </code> </div> @@ -657,6 +813,7 @@ export function PromptDemo() { ) .map((p) => p.text), tagStyle, + curatedTagNames, ).trim()} </code> </div> @@ -693,6 +850,9 @@ export default function AISettings() { {/* Tag Style */} <TagStyleSelector /> + {/* Curated Tags */} + <CuratedTagsSelector /> + {/* Tagging Rules */} <TaggingRules /> diff --git a/apps/web/lib/i18n/locales/en/translation.json b/apps/web/lib/i18n/locales/en/translation.json index 41d5312e..40cf6ece 100644 --- a/apps/web/lib/i18n/locales/en/translation.json +++ b/apps/web/lib/i18n/locales/en/translation.json @@ -268,7 +268,11 @@ "camelCase": "camelCase", "no_preference": "No preference", "inference_language": "Inference Language", - "inference_language_description": "Choose language for AI-generated tags and summaries." + "inference_language_description": "Choose language for AI-generated tags and summaries.", + "curated_tags": "Curated Tags", + "curated_tags_description": "Optionally restrict AI tagging to only use tags from this list. When no tags are selected, the AI generates tags freely.", + "curated_tags_updated": "Curated tags updated successfully!", + "curated_tags_update_failed": "Failed to update curated tags" }, "feeds": { "rss_subscriptions": "RSS Subscriptions", @@ -761,6 +765,8 @@ "create_tag_description": "Create a new tag without attaching it to any bookmark", "tag_name": "Tag Name", "enter_tag_name": "Enter tag name", + "search_placeholder": "Search tags...", + "search_or_create_placeholder": "Search or create tags...", "no_custom_tags": "No custom tags yet", "no_ai_tags": "No AI tags yet", "no_unused_tags": "You don't have any unused tags", diff --git a/apps/web/lib/userSettings.tsx b/apps/web/lib/userSettings.tsx index 41f94cf4..105e258e 100644 --- a/apps/web/lib/userSettings.tsx +++ b/apps/web/lib/userSettings.tsx @@ -19,6 +19,7 @@ export const UserSettingsContext = createContext<ZUserSettings>({ autoTaggingEnabled: null, autoSummarizationEnabled: null, tagStyle: "as-generated", + curatedTagIds: null, inferredTagLang: null, }); diff --git a/apps/workers/workers/inference/tagging.ts b/apps/workers/workers/inference/tagging.ts index b3006193..668c1d5e 100644 --- a/apps/workers/workers/inference/tagging.ts +++ b/apps/workers/workers/inference/tagging.ts @@ -85,6 +85,7 @@ async function buildPrompt( bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, tagStyle: ZTagStyle, inferredTagLang: string, + curatedTags?: string[], ): Promise<string | null> { const prompts = await fetchCustomPrompts(bookmark.userId, "text"); if (bookmark.link) { @@ -110,6 +111,7 @@ Description: ${bookmark.link.description ?? ""} Content: ${content ?? ""}`, serverConfig.inference.contextLength, tagStyle, + curatedTags, ); } @@ -120,6 +122,7 @@ Content: ${content ?? ""}`, bookmark.text.text ?? "", serverConfig.inference.contextLength, tagStyle, + curatedTags, ); } @@ -133,6 +136,7 @@ async function inferTagsFromImage( abortSignal: AbortSignal, tagStyle: ZTagStyle, inferredTagLang: string, + curatedTags?: string[], ): Promise<InferenceResponse | null> { const { asset, metadata } = await readAsset({ userId: bookmark.userId, @@ -160,6 +164,7 @@ async function inferTagsFromImage( inferredTagLang, await fetchCustomPrompts(bookmark.userId, "images"), tagStyle, + curatedTags, ), metadata.contentType, base64, @@ -235,6 +240,7 @@ async function inferTagsFromPDF( abortSignal: AbortSignal, tagStyle: ZTagStyle, inferredTagLang: string, + curatedTags?: string[], ) { const prompt = await buildTextPrompt( inferredTagLang, @@ -242,6 +248,7 @@ async function inferTagsFromPDF( `Content: ${bookmark.asset.content}`, serverConfig.inference.contextLength, tagStyle, + curatedTags, ); setSpanAttributes({ "inference.model": serverConfig.inference.textModel, @@ -261,8 +268,14 @@ async function inferTagsFromText( abortSignal: AbortSignal, tagStyle: ZTagStyle, inferredTagLang: string, + curatedTags?: string[], ) { - const prompt = await buildPrompt(bookmark, tagStyle, inferredTagLang); + const prompt = await buildPrompt( + bookmark, + tagStyle, + inferredTagLang, + curatedTags, + ); if (!prompt) { return null; } @@ -285,6 +298,7 @@ async function inferTags( abortSignal: AbortSignal, tagStyle: ZTagStyle, inferredTagLang: string, + curatedTags?: string[], ) { setSpanAttributes({ "user.id": bookmark.userId, @@ -306,6 +320,7 @@ async function inferTags( abortSignal, tagStyle, inferredTagLang, + curatedTags, ); } else if (bookmark.asset) { switch (bookmark.asset.assetType) { @@ -317,6 +332,7 @@ async function inferTags( abortSignal, tagStyle, inferredTagLang, + curatedTags, ); break; case "pdf": @@ -327,6 +343,7 @@ async function inferTags( abortSignal, tagStyle, inferredTagLang, + curatedTags, ); break; default: @@ -507,6 +524,7 @@ export async function runTagging( columns: { autoTaggingEnabled: true, tagStyle: true, + curatedTagIds: true, inferredTagLang: true, }, }); @@ -518,6 +536,19 @@ export async function runTagging( return; } + // Resolve curated tag names if configured + let curatedTagNames: string[] | undefined; + if (userSettings?.curatedTagIds && userSettings.curatedTagIds.length > 0) { + const tags = await db.query.bookmarkTags.findMany({ + where: and( + eq(bookmarkTags.userId, bookmark.userId), + inArray(bookmarkTags.id, userSettings.curatedTagIds), + ), + columns: { name: true }, + }); + curatedTagNames = tags.map((t) => t.name); + } + logger.info( `[inference][${jobId}] Starting an inference job for bookmark with id "${bookmark.id}"`, ); @@ -529,6 +560,7 @@ export async function runTagging( job.abortSignal, userSettings?.tagStyle ?? "as-generated", userSettings?.inferredTagLang ?? serverConfig.inference.inferredTagLang, + curatedTagNames, ); if (tags === null) { |
