diff options
Diffstat (limited to 'apps/web/components/settings/AISettings.tsx')
| -rw-r--r-- | apps/web/components/settings/AISettings.tsx | 669 |
1 files changed, 584 insertions, 85 deletions
diff --git a/apps/web/components/settings/AISettings.tsx b/apps/web/components/settings/AISettings.tsx index beaa93dc..6d28f4f8 100644 --- a/apps/web/components/settings/AISettings.tsx +++ b/apps/web/components/settings/AISettings.tsx @@ -1,6 +1,25 @@ "use client"; +import React from "react"; +import { TagsEditor } from "@/components/dashboard/bookmarks/TagsEditor"; import { ActionButton } from "@/components/ui/action-button"; +import { Badge } from "@/components/ui/badge"; +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { + Field, + FieldContent, + FieldDescription, + FieldError, + FieldGroup, + FieldLabel, + FieldTitle, +} from "@/components/ui/field"; import { Form, FormControl, @@ -10,6 +29,7 @@ import { } from "@/components/ui/form"; import { FullPageSpinner } from "@/components/ui/full-page-spinner"; import { Input } from "@/components/ui/input"; +import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group"; import { Select, SelectContent, @@ -18,15 +38,22 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import { toast } from "@/components/ui/use-toast"; +import { toast } from "@/components/ui/sonner"; +import { Switch } from "@/components/ui/switch"; import { useClientConfig } from "@/lib/clientConfig"; import { useTranslation } from "@/lib/i18n/client"; -import { api } from "@/lib/trpc"; +import { useUserSettings } from "@/lib/userSettings"; +import { cn } from "@/lib/utils"; import { zodResolver } from "@hookform/resolvers/zod"; -import { Plus, Save, Trash2 } from "lucide-react"; -import { useForm } from "react-hook-form"; +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { Info, Plus, Save, Trash2 } from "lucide-react"; +import { Controller, useForm } from "react-hook-form"; import { z } from "zod"; +import type { ZBookmarkTags } from "@karakeep/shared/types/tags"; +import { useDebounce } from "@karakeep/shared-react/hooks/use-debounce"; +import { useUpdateUserSettings } from "@karakeep/shared-react/hooks/users"; +import { useTRPC } from "@karakeep/shared-react/trpc"; import { buildImagePrompt, buildSummaryPromptUntruncated, @@ -37,10 +64,426 @@ import { ZPrompt, zUpdatePromptSchema, } from "@karakeep/shared/types/prompts"; +import { zUpdateUserSettingsSchema } from "@karakeep/shared/types/users"; + +function SettingsSection({ + title, + description, + children, +}: { + title?: string; + description?: string; + children: React.ReactNode; + className?: string; +}) { + return ( + <Card> + <CardHeader> + {title && <CardTitle>{title}</CardTitle>} + {description && <CardDescription>{description}</CardDescription>} + </CardHeader> + <CardContent>{children}</CardContent> + </Card> + ); +} + +export function AIPreferences() { + const { t } = useTranslation(); + const clientConfig = useClientConfig(); + const settings = useUserSettings(); + + const { mutate: updateSettings, isPending } = useUpdateUserSettings({ + onSuccess: () => { + toast({ + description: "Settings updated successfully!", + }); + }, + onError: () => { + toast({ + description: "Failed to update settings", + variant: "destructive", + }); + }, + }); + + const form = useForm<z.infer<typeof zUpdateUserSettingsSchema>>({ + resolver: zodResolver(zUpdateUserSettingsSchema), + values: settings + ? { + inferredTagLang: settings.inferredTagLang ?? "", + autoTaggingEnabled: settings.autoTaggingEnabled, + autoSummarizationEnabled: settings.autoSummarizationEnabled, + } + : undefined, + }); + + const showAutoTagging = clientConfig.inference.enableAutoTagging; + const showAutoSummarization = clientConfig.inference.enableAutoSummarization; + + const onSubmit = (data: z.infer<typeof zUpdateUserSettingsSchema>) => { + updateSettings(data); + }; + + return ( + <SettingsSection title="AI preferences"> + <form onSubmit={form.handleSubmit(onSubmit)}> + <FieldGroup className="gap-3"> + <Controller + name="inferredTagLang" + control={form.control} + render={({ field, fieldState }) => ( + <Field + className="rounded-lg border p-3" + data-invalid={fieldState.invalid} + > + <FieldContent> + <FieldLabel htmlFor="inferredTagLang"> + {t("settings.ai.inference_language")} + </FieldLabel> + <FieldDescription> + {t("settings.ai.inference_language_description")} + </FieldDescription> + </FieldContent> + <Input + {...field} + id="inferredTagLang" + value={field.value ?? ""} + onChange={(e) => + field.onChange( + e.target.value.length > 0 ? e.target.value : null, + ) + } + aria-invalid={fieldState.invalid} + placeholder={`Default (${clientConfig.inference.inferredTagLang})`} + type="text" + /> + {fieldState.invalid && ( + <FieldError errors={[fieldState.error]} /> + )} + </Field> + )} + /> + + {showAutoTagging && ( + <Controller + name="autoTaggingEnabled" + control={form.control} + render={({ field, fieldState }) => ( + <Field + orientation="horizontal" + className="rounded-lg border p-3" + data-invalid={fieldState.invalid} + > + <FieldContent> + <FieldLabel htmlFor="autoTaggingEnabled"> + {t("settings.ai.auto_tagging")} + </FieldLabel> + <FieldDescription> + {t("settings.ai.auto_tagging_description")} + </FieldDescription> + </FieldContent> + <Switch + id="autoTaggingEnabled" + name={field.name} + checked={field.value ?? true} + onCheckedChange={field.onChange} + aria-invalid={fieldState.invalid} + /> + {fieldState.invalid && ( + <FieldError errors={[fieldState.error]} /> + )} + </Field> + )} + /> + )} + + {showAutoSummarization && ( + <Controller + name="autoSummarizationEnabled" + control={form.control} + render={({ field, fieldState }) => ( + <Field + orientation="horizontal" + className="rounded-lg border p-3" + data-invalid={fieldState.invalid} + > + <FieldContent> + <FieldLabel htmlFor="autoSummarizationEnabled"> + {t("settings.ai.auto_summarization")} + </FieldLabel> + <FieldDescription> + {t("settings.ai.auto_summarization_description")} + </FieldDescription> + </FieldContent> + <Switch + id="autoSummarizationEnabled" + name={field.name} + checked={field.value ?? true} + onCheckedChange={field.onChange} + aria-invalid={fieldState.invalid} + /> + {fieldState.invalid && ( + <FieldError errors={[fieldState.error]} /> + )} + </Field> + )} + /> + )} + + <div className="flex justify-end pt-4"> + <ActionButton type="submit" loading={isPending} variant="default"> + <Save className="mr-2 size-4" /> + {t("actions.save")} + </ActionButton> + </div> + </FieldGroup> + </form> + </SettingsSection> + ); +} + +export function TagStyleSelector() { + const { t } = useTranslation(); + const settings = useUserSettings(); + + const { mutate: updateSettings, isPending: isUpdating } = + useUpdateUserSettings({ + onSuccess: () => { + toast({ + description: "Tag style updated successfully!", + }); + }, + onError: () => { + toast({ + description: "Failed to update tag style", + variant: "destructive", + }); + }, + }); + + const tagStyleOptions = [ + { + value: "lowercase-hyphens", + label: t("settings.ai.lowercase_hyphens"), + examples: ["machine-learning", "web-development"], + }, + { + value: "lowercase-spaces", + label: t("settings.ai.lowercase_spaces"), + examples: ["machine learning", "web development"], + }, + { + value: "lowercase-underscores", + label: t("settings.ai.lowercase_underscores"), + examples: ["machine_learning", "web_development"], + }, + { + value: "titlecase-spaces", + label: t("settings.ai.titlecase_spaces"), + examples: ["Machine Learning", "Web Development"], + }, + { + value: "titlecase-hyphens", + label: t("settings.ai.titlecase_hyphens"), + examples: ["Machine-Learning", "Web-Development"], + }, + { + value: "camelCase", + label: t("settings.ai.camelCase"), + examples: ["machineLearning", "webDevelopment"], + }, + { + value: "as-generated", + label: t("settings.ai.no_preference"), + examples: ["Machine Learning", "web development", "AI_generated"], + }, + ] as const; + + const selectedStyle = settings?.tagStyle ?? "as-generated"; + + return ( + <SettingsSection + title={t("settings.ai.tag_style")} + description={t("settings.ai.tag_style_description")} + > + <RadioGroup + value={selectedStyle} + onValueChange={(value) => { + updateSettings({ tagStyle: value as typeof selectedStyle }); + }} + disabled={isUpdating} + className="grid gap-3 sm:grid-cols-2" + > + {tagStyleOptions.map((option) => ( + <FieldLabel + key={option.value} + htmlFor={option.value} + className={cn(selectedStyle === option.value && "ring-1")} + > + <Field orientation="horizontal"> + <FieldContent> + <FieldTitle>{option.label}</FieldTitle> + <div className="flex flex-wrap gap-1"> + {option.examples.map((example) => ( + <Badge + key={example} + variant="secondary" + className="text-xs font-light" + > + {example} + </Badge> + ))} + </div> + </FieldContent> + <RadioGroupItem value={option.value} id={option.value} /> + </Field> + </FieldLabel> + ))} + </RadioGroup> + </SettingsSection> + ); +} + +export function CuratedTagsSelector() { + const api = useTRPC(); + const { t } = useTranslation(); + const settings = useUserSettings(); + + const { mutate: updateSettings, isPending: isUpdatingCuratedTags } = + useUpdateUserSettings({ + onSuccess: () => { + toast({ + description: t("settings.ai.curated_tags_updated"), + }); + }, + onError: () => { + toast({ + description: t("settings.ai.curated_tags_update_failed"), + variant: "destructive", + }); + }, + }); + + const areTagIdsEqual = React.useCallback((a: string[], b: string[]) => { + return a.length === b.length && a.every((id, index) => id === b[index]); + }, []); + + const curatedTagIds = React.useMemo( + () => settings?.curatedTagIds ?? [], + [settings?.curatedTagIds], + ); + const [localCuratedTagIds, setLocalCuratedTagIds] = + React.useState<string[]>(curatedTagIds); + const debouncedCuratedTagIds = useDebounce(localCuratedTagIds, 300); + const lastServerCuratedTagIdsRef = React.useRef(curatedTagIds); + const lastSubmittedCuratedTagIdsRef = React.useRef<string[] | null>(null); + + React.useEffect(() => { + const hadUnsyncedLocalChanges = !areTagIdsEqual( + localCuratedTagIds, + lastServerCuratedTagIdsRef.current, + ); + + if ( + !hadUnsyncedLocalChanges && + !areTagIdsEqual(localCuratedTagIds, curatedTagIds) + ) { + setLocalCuratedTagIds(curatedTagIds); + } + + lastServerCuratedTagIdsRef.current = curatedTagIds; + }, [areTagIdsEqual, curatedTagIds, localCuratedTagIds]); + + React.useEffect(() => { + if (isUpdatingCuratedTags) { + return; + } + + if (areTagIdsEqual(debouncedCuratedTagIds, curatedTagIds)) { + lastSubmittedCuratedTagIdsRef.current = null; + return; + } + + if ( + lastSubmittedCuratedTagIdsRef.current && + areTagIdsEqual( + lastSubmittedCuratedTagIdsRef.current, + debouncedCuratedTagIds, + ) + ) { + return; + } + + lastSubmittedCuratedTagIdsRef.current = debouncedCuratedTagIds; + updateSettings({ + curatedTagIds: + debouncedCuratedTagIds.length > 0 ? debouncedCuratedTagIds : null, + }); + }, [ + areTagIdsEqual, + curatedTagIds, + debouncedCuratedTagIds, + isUpdatingCuratedTags, + updateSettings, + ]); + + // Fetch selected tags to display their names + const { data: selectedTagsData } = useQuery( + api.tags.list.queryOptions( + { ids: localCuratedTagIds }, + { enabled: localCuratedTagIds.length > 0 }, + ), + ); + + const selectedTags: ZBookmarkTags[] = React.useMemo(() => { + const tagsMap = new Map( + (selectedTagsData?.tags ?? []).map((tag) => [tag.id, tag]), + ); + // Preserve the order from curatedTagIds instead of server sort order + return localCuratedTagIds + .map((id) => tagsMap.get(id)) + .filter((tag): tag is NonNullable<typeof tag> => tag != null) + .map((tag) => ({ + id: tag.id, + name: tag.name, + attachedBy: "human" as const, + })); + }, [selectedTagsData?.tags, localCuratedTagIds]); + + return ( + <SettingsSection + title={t("settings.ai.curated_tags")} + description={t("settings.ai.curated_tags_description")} + > + <TagsEditor + tags={selectedTags} + placeholder="Select curated tags..." + onAttach={(tag) => { + const tagId = tag.tagId; + if (tagId) { + setLocalCuratedTagIds((prev) => { + if (prev.includes(tagId)) { + return prev; + } + return [...prev, tagId]; + }); + } + }} + onDetach={(tag) => { + setLocalCuratedTagIds((prev) => { + return prev.filter((id) => id !== tag.tagId); + }); + }} + allowCreation={false} + /> + </SettingsSection> + ); +} export function PromptEditor() { + const api = useTRPC(); const { t } = useTranslation(); - const apiUtils = api.useUtils(); + const queryClient = useQueryClient(); const form = useForm<z.infer<typeof zNewPromptSchema>>({ resolver: zodResolver(zNewPromptSchema), @@ -50,15 +493,16 @@ export function PromptEditor() { }, }); - const { mutateAsync: createPrompt, isPending: isCreating } = - api.prompts.create.useMutation({ + const { mutateAsync: createPrompt, isPending: isCreating } = useMutation( + api.prompts.create.mutationOptions({ onSuccess: () => { toast({ description: "Prompt has been created!", }); - apiUtils.prompts.list.invalidate(); + queryClient.invalidateQueries(api.prompts.list.pathFilter()); }, - }); + }), + ); return ( <Form {...form}> @@ -140,26 +584,29 @@ export function PromptEditor() { } export function PromptRow({ prompt }: { prompt: ZPrompt }) { + const api = useTRPC(); const { t } = useTranslation(); - const apiUtils = api.useUtils(); - const { mutateAsync: updatePrompt, isPending: isUpdating } = - api.prompts.update.useMutation({ + const queryClient = useQueryClient(); + const { mutateAsync: updatePrompt, isPending: isUpdating } = useMutation( + api.prompts.update.mutationOptions({ onSuccess: () => { toast({ description: "Prompt has been updated!", }); - apiUtils.prompts.list.invalidate(); + queryClient.invalidateQueries(api.prompts.list.pathFilter()); }, - }); - const { mutate: deletePrompt, isPending: isDeleting } = - api.prompts.delete.useMutation({ + }), + ); + const { mutate: deletePrompt, isPending: isDeleting } = useMutation( + api.prompts.delete.mutationOptions({ onSuccess: () => { toast({ description: "Prompt has been deleted!", }); - apiUtils.prompts.list.invalidate(); + queryClient.invalidateQueries(api.prompts.list.pathFilter()); }, - }); + }), + ); const form = useForm<z.infer<typeof zUpdatePromptSchema>>({ resolver: zodResolver(zUpdatePromptSchema), @@ -273,92 +720,144 @@ export function PromptRow({ prompt }: { prompt: ZPrompt }) { } export function TaggingRules() { + const api = useTRPC(); const { t } = useTranslation(); - const { data: prompts, isLoading } = api.prompts.list.useQuery(); + const { data: prompts, isLoading } = useQuery( + api.prompts.list.queryOptions(), + ); return ( - <div className="mt-2 flex flex-col gap-2"> - <div className="w-full text-xl font-medium sm:w-1/3"> - {t("settings.ai.tagging_rules")} - </div> - <p className="mb-1 text-xs italic text-muted-foreground"> - {t("settings.ai.tagging_rule_description")} - </p> - {isLoading && <FullPageSpinner />} + <SettingsSection + title={t("settings.ai.tagging_rules")} + description={t("settings.ai.tagging_rule_description")} + > {prompts && prompts.length == 0 && ( - <p className="rounded-md bg-muted p-2 text-sm text-muted-foreground"> - You don't have any custom prompts yet. - </p> + <div className="flex items-start gap-2 rounded-md bg-muted p-4 text-sm text-muted-foreground"> + <Info className="size-4 flex-shrink-0" /> + <p>You don't have any custom prompts yet.</p> + </div> )} - {prompts && - prompts.map((prompt) => <PromptRow key={prompt.id} prompt={prompt} />)} - <PromptEditor /> - </div> + <div className="flex flex-col gap-2"> + {isLoading && <FullPageSpinner />} + {prompts && + prompts.map((prompt) => ( + <PromptRow key={prompt.id} prompt={prompt} /> + ))} + <PromptEditor /> + </div> + </SettingsSection> ); } export function PromptDemo() { + const api = useTRPC(); const { t } = useTranslation(); - const { data: prompts } = api.prompts.list.useQuery(); + const { data: prompts } = useQuery(api.prompts.list.queryOptions()); + const settings = useUserSettings(); const clientConfig = useClientConfig(); + const tagStyle = settings?.tagStyle ?? "as-generated"; + const curatedTagIds = settings?.curatedTagIds ?? []; + const { data: tagsData } = useQuery( + api.tags.list.queryOptions( + { ids: curatedTagIds }, + { enabled: curatedTagIds.length > 0 }, + ), + ); + const inferredTagLang = + settings?.inferredTagLang ?? clientConfig.inference.inferredTagLang; + + // Resolve curated tag names for preview + const curatedTagNames = + curatedTagIds.length > 0 && tagsData?.tags + ? curatedTagIds + .map((id) => tagsData.tags.find((tag) => tag.id === id)?.name) + .filter((name): name is string => Boolean(name)) + : undefined; + return ( - <div className="flex flex-col gap-2"> - <div className="mb-4 w-full text-xl font-medium sm:w-1/3"> - {t("settings.ai.prompt_preview")} + <SettingsSection + title={t("settings.ai.prompt_preview")} + description="Preview the actual prompts sent to AI based on your settings" + > + <div className="space-y-4"> + <div> + <p className="mb-2 text-sm font-medium"> + {t("settings.ai.text_prompt")} + </p> + <code className="block whitespace-pre-wrap rounded-md bg-muted p-3 text-sm text-muted-foreground"> + {buildTextPromptUntruncated( + inferredTagLang, + (prompts ?? []) + .filter( + (p) => p.appliesTo == "text" || p.appliesTo == "all_tagging", + ) + .map((p) => p.text), + "\n<CONTENT_HERE>\n", + tagStyle, + curatedTagNames, + ).trim()} + </code> + </div> + <div> + <p className="mb-2 text-sm font-medium"> + {t("settings.ai.images_prompt")} + </p> + <code className="block whitespace-pre-wrap rounded-md bg-muted p-3 text-sm text-muted-foreground"> + {buildImagePrompt( + inferredTagLang, + (prompts ?? []) + .filter( + (p) => + p.appliesTo == "images" || p.appliesTo == "all_tagging", + ) + .map((p) => p.text), + tagStyle, + curatedTagNames, + ).trim()} + </code> + </div> + <div> + <p className="mb-2 text-sm font-medium"> + {t("settings.ai.summarization_prompt")} + </p> + <code className="block whitespace-pre-wrap rounded-md bg-muted p-3 text-sm text-muted-foreground"> + {buildSummaryPromptUntruncated( + inferredTagLang, + (prompts ?? []) + .filter((p) => p.appliesTo == "summary") + .map((p) => p.text), + "\n<CONTENT_HERE>\n", + ).trim()} + </code> + </div> </div> - <p>{t("settings.ai.text_prompt")}</p> - <code className="whitespace-pre-wrap rounded-md bg-muted p-3 text-sm text-muted-foreground"> - {buildTextPromptUntruncated( - clientConfig.inference.inferredTagLang, - (prompts ?? []) - .filter( - (p) => p.appliesTo == "text" || p.appliesTo == "all_tagging", - ) - .map((p) => p.text), - "\n<CONTENT_HERE>\n", - ).trim()} - </code> - <p>{t("settings.ai.images_prompt")}</p> - <code className="whitespace-pre-wrap rounded-md bg-muted p-3 text-sm text-muted-foreground"> - {buildImagePrompt( - clientConfig.inference.inferredTagLang, - (prompts ?? []) - .filter( - (p) => p.appliesTo == "images" || p.appliesTo == "all_tagging", - ) - .map((p) => p.text), - ).trim()} - </code> - <p>{t("settings.ai.summarization_prompt")}</p> - <code className="whitespace-pre-wrap rounded-md bg-muted p-3 text-sm text-muted-foreground"> - {buildSummaryPromptUntruncated( - clientConfig.inference.inferredTagLang, - (prompts ?? []) - .filter((p) => p.appliesTo == "summary") - .map((p) => p.text), - "\n<CONTENT_HERE>\n", - ).trim()} - </code> - </div> + </SettingsSection> ); } export default function AISettings() { const { t } = useTranslation(); return ( - <> - <div className="rounded-md border bg-background p-4"> - <div className="mb-2 flex flex-col gap-3"> - <div className="w-full text-2xl font-medium sm:w-1/3"> - {t("settings.ai.ai_settings")} - </div> - <TaggingRules /> - </div> - </div> - <div className="mt-4 rounded-md border bg-background p-4"> - <PromptDemo /> - </div> - </> + <div className="space-y-6"> + <h2 className="text-3xl font-bold tracking-tight"> + {t("settings.ai.ai_settings")} + </h2> + + {/* AI Preferences */} + <AIPreferences /> + + {/* Tag Style */} + <TagStyleSelector /> + + {/* Curated Tags */} + <CuratedTagsSelector /> + + {/* Tagging Rules */} + <TaggingRules /> + + {/* Prompt Preview */} + <PromptDemo /> + </div> ); } |
