From 4186c4c64c68892248ce8671d9b8e67fc7f884a0 Mon Sep 17 00:00:00 2001 From: Mohamed Bassem Date: Mon, 9 Feb 2026 00:09:10 +0000 Subject: feat(ai): Support restricting AI tags to a subset of existing tags (#2444) * feat(ai): Support restricting AI tags to a subset of existing tags Co-authored-by: Claude --- apps/web/components/settings/AISettings.tsx | 160 ++++++++++++++++++++++++++++ 1 file changed, 160 insertions(+) (limited to 'apps/web/components/settings') diff --git a/apps/web/components/settings/AISettings.tsx b/apps/web/components/settings/AISettings.tsx index 58710fe8..6d28f4f8 100644 --- a/apps/web/components/settings/AISettings.tsx +++ b/apps/web/components/settings/AISettings.tsx @@ -1,5 +1,7 @@ "use client"; +import React from "react"; +import { TagsEditor } from "@/components/dashboard/bookmarks/TagsEditor"; import { ActionButton } from "@/components/ui/action-button"; import { Badge } from "@/components/ui/badge"; import { @@ -48,6 +50,8 @@ import { Info, Plus, Save, Trash2 } from "lucide-react"; import { Controller, useForm } from "react-hook-form"; import { z } from "zod"; +import type { ZBookmarkTags } from "@karakeep/shared/types/tags"; +import { useDebounce } from "@karakeep/shared-react/hooks/use-debounce"; import { useUpdateUserSettings } from "@karakeep/shared-react/hooks/users"; import { useTRPC } from "@karakeep/shared-react/trpc"; import { @@ -340,6 +344,142 @@ export function TagStyleSelector() { ); } +export function CuratedTagsSelector() { + const api = useTRPC(); + const { t } = useTranslation(); + const settings = useUserSettings(); + + const { mutate: updateSettings, isPending: isUpdatingCuratedTags } = + useUpdateUserSettings({ + onSuccess: () => { + toast({ + description: t("settings.ai.curated_tags_updated"), + }); + }, + onError: () => { + toast({ + description: t("settings.ai.curated_tags_update_failed"), + variant: "destructive", + }); + }, + }); + + const areTagIdsEqual = React.useCallback((a: string[], b: string[]) => { + return a.length === b.length && a.every((id, index) => id === b[index]); + }, []); + + const curatedTagIds = React.useMemo( + () => settings?.curatedTagIds ?? [], + [settings?.curatedTagIds], + ); + const [localCuratedTagIds, setLocalCuratedTagIds] = + React.useState(curatedTagIds); + const debouncedCuratedTagIds = useDebounce(localCuratedTagIds, 300); + const lastServerCuratedTagIdsRef = React.useRef(curatedTagIds); + const lastSubmittedCuratedTagIdsRef = React.useRef(null); + + React.useEffect(() => { + const hadUnsyncedLocalChanges = !areTagIdsEqual( + localCuratedTagIds, + lastServerCuratedTagIdsRef.current, + ); + + if ( + !hadUnsyncedLocalChanges && + !areTagIdsEqual(localCuratedTagIds, curatedTagIds) + ) { + setLocalCuratedTagIds(curatedTagIds); + } + + lastServerCuratedTagIdsRef.current = curatedTagIds; + }, [areTagIdsEqual, curatedTagIds, localCuratedTagIds]); + + React.useEffect(() => { + if (isUpdatingCuratedTags) { + return; + } + + if (areTagIdsEqual(debouncedCuratedTagIds, curatedTagIds)) { + lastSubmittedCuratedTagIdsRef.current = null; + return; + } + + if ( + lastSubmittedCuratedTagIdsRef.current && + areTagIdsEqual( + lastSubmittedCuratedTagIdsRef.current, + debouncedCuratedTagIds, + ) + ) { + return; + } + + lastSubmittedCuratedTagIdsRef.current = debouncedCuratedTagIds; + updateSettings({ + curatedTagIds: + debouncedCuratedTagIds.length > 0 ? debouncedCuratedTagIds : null, + }); + }, [ + areTagIdsEqual, + curatedTagIds, + debouncedCuratedTagIds, + isUpdatingCuratedTags, + updateSettings, + ]); + + // Fetch selected tags to display their names + const { data: selectedTagsData } = useQuery( + api.tags.list.queryOptions( + { ids: localCuratedTagIds }, + { enabled: localCuratedTagIds.length > 0 }, + ), + ); + + const selectedTags: ZBookmarkTags[] = React.useMemo(() => { + const tagsMap = new Map( + (selectedTagsData?.tags ?? []).map((tag) => [tag.id, tag]), + ); + // Preserve the order from curatedTagIds instead of server sort order + return localCuratedTagIds + .map((id) => tagsMap.get(id)) + .filter((tag): tag is NonNullable => tag != null) + .map((tag) => ({ + id: tag.id, + name: tag.name, + attachedBy: "human" as const, + })); + }, [selectedTagsData?.tags, localCuratedTagIds]); + + return ( + + { + const tagId = tag.tagId; + if (tagId) { + setLocalCuratedTagIds((prev) => { + if (prev.includes(tagId)) { + return prev; + } + return [...prev, tagId]; + }); + } + }} + onDetach={(tag) => { + setLocalCuratedTagIds((prev) => { + return prev.filter((id) => id !== tag.tagId); + }); + }} + allowCreation={false} + /> + + ); +} + export function PromptEditor() { const api = useTRPC(); const { t } = useTranslation(); @@ -617,9 +757,24 @@ export function PromptDemo() { const clientConfig = useClientConfig(); const tagStyle = settings?.tagStyle ?? "as-generated"; + const curatedTagIds = settings?.curatedTagIds ?? []; + const { data: tagsData } = useQuery( + api.tags.list.queryOptions( + { ids: curatedTagIds }, + { enabled: curatedTagIds.length > 0 }, + ), + ); const inferredTagLang = settings?.inferredTagLang ?? clientConfig.inference.inferredTagLang; + // Resolve curated tag names for preview + const curatedTagNames = + curatedTagIds.length > 0 && tagsData?.tags + ? curatedTagIds + .map((id) => tagsData.tags.find((tag) => tag.id === id)?.name) + .filter((name): name is string => Boolean(name)) + : undefined; + return ( p.text), "\n\n", tagStyle, + curatedTagNames, ).trim()} @@ -657,6 +813,7 @@ export function PromptDemo() { ) .map((p) => p.text), tagStyle, + curatedTagNames, ).trim()} @@ -693,6 +850,9 @@ export default function AISettings() { {/* Tag Style */} + {/* Curated Tags */} + + {/* Tagging Rules */} -- cgit v1.2.3-70-g09d2