diff options
Diffstat (limited to 'apps')
| -rw-r--r-- | apps/web/components/settings/AISettings.tsx | 465 | ||||
| -rw-r--r-- | apps/web/components/ui/field.tsx | 244 | ||||
| -rw-r--r-- | apps/web/components/ui/radio-group.tsx | 43 | ||||
| -rw-r--r-- | apps/web/lib/i18n/locales/en/translation.json | 14 | ||||
| -rw-r--r-- | apps/web/lib/userSettings.tsx | 2 | ||||
| -rw-r--r-- | apps/web/package.json | 1 | ||||
| -rw-r--r-- | apps/workers/workers/inference/summarize.ts | 3 | ||||
| -rw-r--r-- | apps/workers/workers/inference/tagging.ts | 41 |
8 files changed, 682 insertions, 131 deletions
diff --git a/apps/web/components/settings/AISettings.tsx b/apps/web/components/settings/AISettings.tsx index d8adcb76..48c45633 100644 --- a/apps/web/components/settings/AISettings.tsx +++ b/apps/web/components/settings/AISettings.tsx @@ -1,17 +1,33 @@ "use client"; import { ActionButton } from "@/components/ui/action-button"; +import { Badge } from "@/components/ui/badge"; +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { + Field, + FieldContent, + FieldDescription, + FieldError, + FieldGroup, + FieldLabel, + FieldTitle, +} from "@/components/ui/field"; import { Form, FormControl, - FormDescription, FormField, FormItem, - FormLabel, FormMessage, } from "@/components/ui/form"; import { FullPageSpinner } from "@/components/ui/full-page-spinner"; import { Input } from "@/components/ui/input"; +import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group"; import { Select, SelectContent, @@ -26,9 +42,10 @@ import { useClientConfig } from "@/lib/clientConfig"; import { useTranslation } from "@/lib/i18n/client"; import { api } from "@/lib/trpc"; import { useUserSettings } from "@/lib/userSettings"; +import { cn } from "@/lib/utils"; import { zodResolver } from "@hookform/resolvers/zod"; -import { Plus, Save, Trash2 } from "lucide-react"; -import { useForm } from "react-hook-form"; +import { Info, Plus, Save, Trash2 } from "lucide-react"; +import { Controller, useForm } from "react-hook-form"; import { z } from "zod"; import { useUpdateUserSettings } from "@karakeep/shared-react/hooks/users"; @@ -44,12 +61,33 @@ import { } from "@karakeep/shared/types/prompts"; import { zUpdateUserSettingsSchema } from "@karakeep/shared/types/users"; +function SettingsSection({ + title, + description, + children, +}: { + title?: string; + description?: string; + children: React.ReactNode; + className?: string; +}) { + return ( + <Card> + <CardHeader> + {title && <CardTitle>{title}</CardTitle>} + {description && <CardDescription>{description}</CardDescription>} + </CardHeader> + <CardContent>{children}</CardContent> + </Card> + ); +} + export function AIPreferences() { const { t } = useTranslation(); const clientConfig = useClientConfig(); const settings = useUserSettings(); - const { mutate: updateSettings } = useUpdateUserSettings({ + const { mutate: updateSettings, isPending } = useUpdateUserSettings({ onSuccess: () => { toast({ description: "Settings updated successfully!", @@ -67,6 +105,7 @@ export function AIPreferences() { resolver: zodResolver(zUpdateUserSettingsSchema), values: settings ? { + inferredTagLang: settings.inferredTagLang ?? "", autoTaggingEnabled: settings.autoTaggingEnabled, autoSummarizationEnabled: settings.autoSummarizationEnabled, } @@ -76,72 +115,227 @@ export function AIPreferences() { const showAutoTagging = clientConfig.inference.enableAutoTagging; const showAutoSummarization = clientConfig.inference.enableAutoSummarization; - // Don't show the section if neither feature is enabled on the server - if (!showAutoTagging && !showAutoSummarization) { - return null; - } + const onSubmit = (data: z.infer<typeof zUpdateUserSettingsSchema>) => { + updateSettings(data); + }; return ( - <div className="mt-2 flex flex-col gap-2"> - <p className="mb-1 text-xs italic text-muted-foreground"> - {t("settings.ai.ai_preferences_description")} - </p> - <Form {...form}> - <form className="space-y-4"> + <SettingsSection title="AI preferences"> + <form onSubmit={form.handleSubmit(onSubmit)}> + <FieldGroup className="gap-3"> + <Controller + name="inferredTagLang" + control={form.control} + render={({ field, fieldState }) => ( + <Field + className="rounded-lg border p-3" + data-invalid={fieldState.invalid} + > + <FieldContent> + <FieldLabel htmlFor="inferredTagLang"> + {t("settings.ai.inference_language")} + </FieldLabel> + <FieldDescription> + {t("settings.ai.inference_language_description")} + </FieldDescription> + </FieldContent> + <Input + {...field} + id="inferredTagLang" + value={field.value ?? ""} + onChange={(e) => + field.onChange( + e.target.value.length > 0 ? e.target.value : null, + ) + } + aria-invalid={fieldState.invalid} + placeholder={`Default (${clientConfig.inference.inferredTagLang})`} + type="text" + /> + {fieldState.invalid && ( + <FieldError errors={[fieldState.error]} /> + )} + </Field> + )} + /> + {showAutoTagging && ( - <FormField - control={form.control} + <Controller name="autoTaggingEnabled" - render={({ field }) => ( - <FormItem className="flex flex-row items-center justify-between rounded-lg border p-3"> - <div className="space-y-0.5"> - <FormLabel>{t("settings.ai.auto_tagging")}</FormLabel> - <FormDescription> + control={form.control} + render={({ field, fieldState }) => ( + <Field + orientation="horizontal" + className="rounded-lg border p-3" + data-invalid={fieldState.invalid} + > + <FieldContent> + <FieldLabel htmlFor="autoTaggingEnabled"> + {t("settings.ai.auto_tagging")} + </FieldLabel> + <FieldDescription> {t("settings.ai.auto_tagging_description")} - </FormDescription> - </div> - <FormControl> - <Switch - checked={field.value ?? true} - onCheckedChange={(checked) => { - field.onChange(checked); - updateSettings({ autoTaggingEnabled: checked }); - }} - /> - </FormControl> - </FormItem> + </FieldDescription> + </FieldContent> + <Switch + id="autoTaggingEnabled" + name={field.name} + checked={field.value ?? true} + onCheckedChange={field.onChange} + aria-invalid={fieldState.invalid} + /> + {fieldState.invalid && ( + <FieldError errors={[fieldState.error]} /> + )} + </Field> )} /> )} {showAutoSummarization && ( - <FormField - control={form.control} + <Controller name="autoSummarizationEnabled" - render={({ field }) => ( - <FormItem className="flex flex-row items-center justify-between rounded-lg border p-3"> - <div className="space-y-0.5"> - <FormLabel>{t("settings.ai.auto_summarization")}</FormLabel> - <FormDescription> + control={form.control} + render={({ field, fieldState }) => ( + <Field + orientation="horizontal" + className="rounded-lg border p-3" + data-invalid={fieldState.invalid} + > + <FieldContent> + <FieldLabel htmlFor="autoSummarizationEnabled"> + {t("settings.ai.auto_summarization")} + </FieldLabel> + <FieldDescription> {t("settings.ai.auto_summarization_description")} - </FormDescription> - </div> - <FormControl> - <Switch - checked={field.value ?? true} - onCheckedChange={(checked) => { - field.onChange(checked); - updateSettings({ autoSummarizationEnabled: checked }); - }} - /> - </FormControl> - </FormItem> + </FieldDescription> + </FieldContent> + <Switch + id="autoSummarizationEnabled" + name={field.name} + checked={field.value ?? true} + onCheckedChange={field.onChange} + aria-invalid={fieldState.invalid} + /> + {fieldState.invalid && ( + <FieldError errors={[fieldState.error]} /> + )} + </Field> )} /> )} - </form> - </Form> - </div> + + <div className="flex justify-end pt-4"> + <ActionButton type="submit" loading={isPending} variant="default"> + <Save className="mr-2 size-4" /> + {t("actions.save")} + </ActionButton> + </div> + </FieldGroup> + </form> + </SettingsSection> + ); +} + +export function TagStyleSelector() { + const { t } = useTranslation(); + const settings = useUserSettings(); + + const { mutate: updateSettings, isPending: isUpdating } = + useUpdateUserSettings({ + onSuccess: () => { + toast({ + description: "Tag style updated successfully!", + }); + }, + onError: () => { + toast({ + description: "Failed to update tag style", + variant: "destructive", + }); + }, + }); + + const tagStyleOptions = [ + { + value: "lowercase-hyphens", + label: t("settings.ai.lowercase_hyphens"), + examples: ["machine-learning", "web-development"], + }, + { + value: "lowercase-spaces", + label: t("settings.ai.lowercase_spaces"), + examples: ["machine learning", "web development"], + }, + { + value: "lowercase-underscores", + label: t("settings.ai.lowercase_underscores"), + examples: ["machine_learning", "web_development"], + }, + { + value: "titlecase-spaces", + label: t("settings.ai.titlecase_spaces"), + examples: ["Machine Learning", "Web Development"], + }, + { + value: "titlecase-hyphens", + label: t("settings.ai.titlecase_hyphens"), + examples: ["Machine-Learning", "Web-Development"], + }, + { + value: "camelCase", + label: t("settings.ai.camelCase"), + examples: ["machineLearning", "webDevelopment"], + }, + { + value: "as-generated", + label: t("settings.ai.no_preference"), + examples: ["Machine Learning", "web development", "AI_generated"], + }, + ] as const; + + const selectedStyle = settings?.tagStyle ?? "as-generated"; + + return ( + <SettingsSection + title={t("settings.ai.tag_style")} + description={t("settings.ai.tag_style_description")} + > + <RadioGroup + value={selectedStyle} + onValueChange={(value) => { + updateSettings({ tagStyle: value as typeof selectedStyle }); + }} + disabled={isUpdating} + className="grid gap-3 sm:grid-cols-2" + > + {tagStyleOptions.map((option) => ( + <FieldLabel + key={option.value} + htmlFor={option.value} + className={cn(selectedStyle === option.value && "ring-1")} + > + <Field orientation="horizontal"> + <FieldContent> + <FieldTitle>{option.label}</FieldTitle> + <div className="flex flex-wrap gap-1"> + {option.examples.map((example) => ( + <Badge + key={example} + variant="secondary" + className="text-xs font-light" + > + {example} + </Badge> + ))} + </div> + </FieldContent> + <RadioGroupItem value={option.value} id={option.value} /> + </Field> + </FieldLabel> + ))} + </RadioGroup> + </SettingsSection> ); } @@ -384,89 +578,116 @@ export function TaggingRules() { const { data: prompts, isLoading } = api.prompts.list.useQuery(); return ( - <div className="mt-2 flex flex-col gap-2"> - <div className="w-full text-xl font-medium sm:w-1/3"> - {t("settings.ai.tagging_rules")} - </div> - <p className="mb-1 text-xs italic text-muted-foreground"> - {t("settings.ai.tagging_rule_description")} - </p> - {isLoading && <FullPageSpinner />} + <SettingsSection + title={t("settings.ai.tagging_rules")} + description={t("settings.ai.tagging_rule_description")} + > {prompts && prompts.length == 0 && ( - <p className="rounded-md bg-muted p-2 text-sm text-muted-foreground"> - You don't have any custom prompts yet. - </p> + <div className="flex items-start gap-2 rounded-md bg-muted p-4 text-sm text-muted-foreground"> + <Info className="size-4 flex-shrink-0" /> + <p>You don't have any custom prompts yet.</p> + </div> )} - {prompts && - prompts.map((prompt) => <PromptRow key={prompt.id} prompt={prompt} />)} - <PromptEditor /> - </div> + <div className="flex flex-col gap-2"> + {isLoading && <FullPageSpinner />} + {prompts && + prompts.map((prompt) => ( + <PromptRow key={prompt.id} prompt={prompt} /> + ))} + <PromptEditor /> + </div> + </SettingsSection> ); } export function PromptDemo() { const { t } = useTranslation(); const { data: prompts } = api.prompts.list.useQuery(); + const settings = useUserSettings(); const clientConfig = useClientConfig(); + const tagStyle = settings?.tagStyle ?? "as-generated"; + const inferredTagLang = + settings?.inferredTagLang ?? clientConfig.inference.inferredTagLang; + return ( - <div className="flex flex-col gap-2"> - <div className="mb-4 w-full text-xl font-medium sm:w-1/3"> - {t("settings.ai.prompt_preview")} + <SettingsSection + title={t("settings.ai.prompt_preview")} + description="Preview the actual prompts sent to AI based on your settings" + > + <div className="space-y-4"> + <div> + <p className="mb-2 text-sm font-medium"> + {t("settings.ai.text_prompt")} + </p> + <code className="block whitespace-pre-wrap rounded-md bg-muted p-3 text-sm text-muted-foreground"> + {buildTextPromptUntruncated( + inferredTagLang, + (prompts ?? []) + .filter( + (p) => p.appliesTo == "text" || p.appliesTo == "all_tagging", + ) + .map((p) => p.text), + "\n<CONTENT_HERE>\n", + tagStyle, + ).trim()} + </code> + </div> + <div> + <p className="mb-2 text-sm font-medium"> + {t("settings.ai.images_prompt")} + </p> + <code className="block whitespace-pre-wrap rounded-md bg-muted p-3 text-sm text-muted-foreground"> + {buildImagePrompt( + inferredTagLang, + (prompts ?? []) + .filter( + (p) => + p.appliesTo == "images" || p.appliesTo == "all_tagging", + ) + .map((p) => p.text), + tagStyle, + ).trim()} + </code> + </div> + <div> + <p className="mb-2 text-sm font-medium"> + {t("settings.ai.summarization_prompt")} + </p> + <code className="block whitespace-pre-wrap rounded-md bg-muted p-3 text-sm text-muted-foreground"> + {buildSummaryPromptUntruncated( + inferredTagLang, + (prompts ?? []) + .filter((p) => p.appliesTo == "summary") + .map((p) => p.text), + "\n<CONTENT_HERE>\n", + ).trim()} + </code> + </div> </div> - <p>{t("settings.ai.text_prompt")}</p> - <code className="whitespace-pre-wrap rounded-md bg-muted p-3 text-sm text-muted-foreground"> - {buildTextPromptUntruncated( - clientConfig.inference.inferredTagLang, - (prompts ?? []) - .filter( - (p) => p.appliesTo == "text" || p.appliesTo == "all_tagging", - ) - .map((p) => p.text), - "\n<CONTENT_HERE>\n", - ).trim()} - </code> - <p>{t("settings.ai.images_prompt")}</p> - <code className="whitespace-pre-wrap rounded-md bg-muted p-3 text-sm text-muted-foreground"> - {buildImagePrompt( - clientConfig.inference.inferredTagLang, - (prompts ?? []) - .filter( - (p) => p.appliesTo == "images" || p.appliesTo == "all_tagging", - ) - .map((p) => p.text), - ).trim()} - </code> - <p>{t("settings.ai.summarization_prompt")}</p> - <code className="whitespace-pre-wrap rounded-md bg-muted p-3 text-sm text-muted-foreground"> - {buildSummaryPromptUntruncated( - clientConfig.inference.inferredTagLang, - (prompts ?? []) - .filter((p) => p.appliesTo == "summary") - .map((p) => p.text), - "\n<CONTENT_HERE>\n", - ).trim()} - </code> - </div> + </SettingsSection> ); } export default function AISettings() { const { t } = useTranslation(); return ( - <> - <div className="rounded-md border bg-background p-4"> - <div className="mb-2 flex flex-col gap-3"> - <div className="w-full text-2xl font-medium sm:w-1/3"> - {t("settings.ai.ai_settings")} - </div> - <AIPreferences /> - <TaggingRules /> - </div> - </div> - <div className="mt-4 rounded-md border bg-background p-4"> - <PromptDemo /> - </div> - </> + <div className="space-y-6"> + <h2 className="text-3xl font-bold tracking-tight"> + {t("settings.ai.ai_settings")} + </h2> + + {/* AI Preferences */} + <AIPreferences /> + + {/* Tag Style */} + <TagStyleSelector /> + + {/* Tagging Rules */} + <TaggingRules /> + + {/* Prompt Preview */} + <PromptDemo /> + </div> ); } diff --git a/apps/web/components/ui/field.tsx b/apps/web/components/ui/field.tsx new file mode 100644 index 00000000..a52897f5 --- /dev/null +++ b/apps/web/components/ui/field.tsx @@ -0,0 +1,244 @@ +"use client"; + +import type { VariantProps } from "class-variance-authority"; +import { useMemo } from "react"; +import { Label } from "@/components/ui/label"; +import { Separator } from "@/components/ui/separator"; +import { cn } from "@/lib/utils"; +import { cva } from "class-variance-authority"; + +function FieldSet({ className, ...props }: React.ComponentProps<"fieldset">) { + return ( + <fieldset + data-slot="field-set" + className={cn( + "flex flex-col gap-6", + "has-[>[data-slot=checkbox-group]]:gap-3 has-[>[data-slot=radio-group]]:gap-3", + className, + )} + {...props} + /> + ); +} + +function FieldLegend({ + className, + variant = "legend", + ...props +}: React.ComponentProps<"legend"> & { variant?: "legend" | "label" }) { + return ( + <legend + data-slot="field-legend" + data-variant={variant} + className={cn( + "mb-3 font-medium", + "data-[variant=legend]:text-base", + "data-[variant=label]:text-sm", + className, + )} + {...props} + /> + ); +} + +function FieldGroup({ className, ...props }: React.ComponentProps<"div">) { + return ( + <div + data-slot="field-group" + className={cn( + "group/field-group @container/field-group flex w-full flex-col gap-7 data-[slot=checkbox-group]:gap-3 [&>[data-slot=field-group]]:gap-4", + className, + )} + {...props} + /> + ); +} + +const fieldVariants = cva( + "group/field flex w-full gap-3 data-[invalid=true]:text-destructive", + { + variants: { + orientation: { + vertical: ["flex-col [&>*]:w-full [&>.sr-only]:w-auto"], + horizontal: [ + "flex-row items-center", + "[&>[data-slot=field-label]]:flex-auto", + "has-[>[data-slot=field-content]]:[&>[role=checkbox],[role=radio]]:mt-px has-[>[data-slot=field-content]]:items-start", + ], + responsive: [ + "@md/field-group:flex-row @md/field-group:items-center @md/field-group:[&>*]:w-auto flex-col [&>*]:w-full [&>.sr-only]:w-auto", + "@md/field-group:[&>[data-slot=field-label]]:flex-auto", + "@md/field-group:has-[>[data-slot=field-content]]:items-start @md/field-group:has-[>[data-slot=field-content]]:[&>[role=checkbox],[role=radio]]:mt-px", + ], + }, + }, + defaultVariants: { + orientation: "vertical", + }, + }, +); + +function Field({ + className, + orientation = "vertical", + ...props +}: React.ComponentProps<"div"> & VariantProps<typeof fieldVariants>) { + return ( + <div + role="group" + data-slot="field" + data-orientation={orientation} + className={cn(fieldVariants({ orientation }), className)} + {...props} + /> + ); +} + +function FieldContent({ className, ...props }: React.ComponentProps<"div">) { + return ( + <div + data-slot="field-content" + className={cn( + "group/field-content flex flex-1 flex-col gap-1.5 leading-snug", + className, + )} + {...props} + /> + ); +} + +function FieldLabel({ + className, + ...props +}: React.ComponentProps<typeof Label>) { + return ( + <Label + data-slot="field-label" + className={cn( + "group/field-label peer/field-label flex w-fit gap-2 leading-snug group-data-[disabled=true]/field:opacity-50", + "has-[>[data-slot=field]]:w-full has-[>[data-slot=field]]:flex-col has-[>[data-slot=field]]:rounded-md has-[>[data-slot=field]]:border [&>[data-slot=field]]:p-4", + "has-data-[state=checked]:bg-primary/5 has-data-[state=checked]:border-primary dark:has-data-[state=checked]:bg-primary/10", + className, + )} + {...props} + /> + ); +} + +function FieldTitle({ className, ...props }: React.ComponentProps<"div">) { + return ( + <div + data-slot="field-label" + className={cn( + "flex w-fit items-center gap-2 text-sm font-medium leading-snug group-data-[disabled=true]/field:opacity-50", + className, + )} + {...props} + /> + ); +} + +function FieldDescription({ className, ...props }: React.ComponentProps<"p">) { + return ( + <p + data-slot="field-description" + className={cn( + "text-sm font-normal leading-normal text-muted-foreground group-has-[[data-orientation=horizontal]]/field:text-balance", + "nth-last-2:-mt-1 last:mt-0 [[data-variant=legend]+&]:-mt-1.5", + "[&>a:hover]:text-primary [&>a]:underline [&>a]:underline-offset-4", + className, + )} + {...props} + /> + ); +} + +function FieldSeparator({ + children, + className, + ...props +}: React.ComponentProps<"div"> & { + children?: React.ReactNode; +}) { + return ( + <div + data-slot="field-separator" + data-content={!!children} + className={cn( + "relative -my-2 h-5 text-sm group-data-[variant=outline]/field-group:-mb-2", + className, + )} + {...props} + > + <Separator className="absolute inset-0 top-1/2" /> + {children && ( + <span + className="relative mx-auto block w-fit bg-background px-2 text-muted-foreground" + data-slot="field-separator-content" + > + {children} + </span> + )} + </div> + ); +} + +function FieldError({ + className, + children, + errors, + ...props +}: React.ComponentProps<"div"> & { + errors?: ({ message?: string } | undefined)[]; +}) { + const content = useMemo(() => { + if (children) { + return children; + } + + if (!errors) { + return null; + } + + if (errors?.length === 1 && errors[0]?.message) { + return errors[0].message; + } + + return ( + <ul className="ml-4 flex list-disc flex-col gap-1"> + {errors.map( + (error, index) => + error?.message && <li key={index}>{error.message}</li>, + )} + </ul> + ); + }, [children, errors]); + + if (!content) { + return null; + } + + return ( + <div + role="alert" + data-slot="field-error" + className={cn("text-sm font-normal text-destructive", className)} + {...props} + > + {content} + </div> + ); +} + +export { + Field, + FieldLabel, + FieldDescription, + FieldError, + FieldGroup, + FieldLegend, + FieldSeparator, + FieldSet, + FieldContent, + FieldTitle, +}; diff --git a/apps/web/components/ui/radio-group.tsx b/apps/web/components/ui/radio-group.tsx new file mode 100644 index 00000000..0da1136e --- /dev/null +++ b/apps/web/components/ui/radio-group.tsx @@ -0,0 +1,43 @@ +"use client"; + +import * as React from "react"; +import { cn } from "@/lib/utils"; +import * as RadioGroupPrimitive from "@radix-ui/react-radio-group"; +import { Circle } from "lucide-react"; + +const RadioGroup = React.forwardRef< + React.ElementRef<typeof RadioGroupPrimitive.Root>, + React.ComponentPropsWithoutRef<typeof RadioGroupPrimitive.Root> +>(({ className, ...props }, ref) => { + return ( + <RadioGroupPrimitive.Root + className={cn("grid gap-2", className)} + {...props} + ref={ref} + /> + ); +}); +RadioGroup.displayName = RadioGroupPrimitive.Root.displayName; + +const RadioGroupItem = React.forwardRef< + React.ElementRef<typeof RadioGroupPrimitive.Item>, + React.ComponentPropsWithoutRef<typeof RadioGroupPrimitive.Item> +>(({ className, ...props }, ref) => { + return ( + <RadioGroupPrimitive.Item + ref={ref} + className={cn( + "aspect-square h-4 w-4 rounded-full border border-primary text-primary ring-offset-background focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50", + className, + )} + {...props} + > + <RadioGroupPrimitive.Indicator className="flex items-center justify-center"> + <Circle className="h-2.5 w-2.5 fill-current text-current" /> + </RadioGroupPrimitive.Indicator> + </RadioGroupPrimitive.Item> + ); +}); +RadioGroupItem.displayName = RadioGroupPrimitive.Item.displayName; + +export { RadioGroup, RadioGroupItem }; diff --git a/apps/web/lib/i18n/locales/en/translation.json b/apps/web/lib/i18n/locales/en/translation.json index af9b2748..03aaa645 100644 --- a/apps/web/lib/i18n/locales/en/translation.json +++ b/apps/web/lib/i18n/locales/en/translation.json @@ -236,7 +236,6 @@ }, "ai": { "ai_settings": "AI Settings", - "ai_preferences_description": "Control which AI features are enabled for your account.", "auto_tagging": "Auto-tagging", "auto_tagging_description": "Automatically generate tags for your bookmarks using AI.", "auto_summarization": "Auto-summarization", @@ -250,7 +249,18 @@ "all_tagging": "All Tagging", "text_tagging": "Text Tagging", "image_tagging": "Image Tagging", - "summarization": "Summarization" + "summarization": "Summarization", + "tag_style": "Tag Style", + "tag_style_description": "Choose how your auto-generated tags should be formatted.", + "lowercase_hyphens": "Lowercase with hyphens", + "lowercase_spaces": "Lowercase with spaces", + "lowercase_underscores": "Lowercase with underscores", + "titlecase_spaces": "Title case with spaces", + "titlecase_hyphens": "Title case with hyphens", + "camelCase": "camelCase", + "no_preference": "No preference", + "inference_language": "Inference Language", + "inference_language_description": "Choose language for AI-generated tags and summaries." }, "feeds": { "rss_subscriptions": "RSS Subscriptions", diff --git a/apps/web/lib/userSettings.tsx b/apps/web/lib/userSettings.tsx index d35c9e56..4789e2ba 100644 --- a/apps/web/lib/userSettings.tsx +++ b/apps/web/lib/userSettings.tsx @@ -18,6 +18,8 @@ export const UserSettingsContext = createContext<ZUserSettings>({ readerFontFamily: null, autoTaggingEnabled: null, autoSummarizationEnabled: null, + tagStyle: "as-generated", + inferredTagLang: null, }); export function UserSettingsContextProvider({ diff --git a/apps/web/package.json b/apps/web/package.json index 5f8eff0d..0400bc5c 100644 --- a/apps/web/package.json +++ b/apps/web/package.json @@ -40,6 +40,7 @@ "@radix-ui/react-label": "^2.1.7", "@radix-ui/react-popover": "^1.1.14", "@radix-ui/react-progress": "^1.1.7", + "@radix-ui/react-radio-group": "^1.3.8", "@radix-ui/react-scroll-area": "^1.2.9", "@radix-ui/react-select": "^2.2.5", "@radix-ui/react-separator": "^1.1.7", diff --git a/apps/workers/workers/inference/summarize.ts b/apps/workers/workers/inference/summarize.ts index 460c3328..560bb5a2 100644 --- a/apps/workers/workers/inference/summarize.ts +++ b/apps/workers/workers/inference/summarize.ts @@ -61,6 +61,7 @@ export async function runSummarization( where: eq(users.id, bookmarkData.userId), columns: { autoSummarizationEnabled: true, + inferredTagLang: true, }, }); @@ -121,7 +122,7 @@ URL: ${link.url ?? ""} }); const summaryPrompt = await buildSummaryPrompt( - serverConfig.inference.inferredTagLang, + userSettings?.inferredTagLang ?? serverConfig.inference.inferredTagLang, prompts.map((p) => p.text), textToSummarize, serverConfig.inference.contextLength, diff --git a/apps/workers/workers/inference/tagging.ts b/apps/workers/workers/inference/tagging.ts index 6d20b953..ace426a1 100644 --- a/apps/workers/workers/inference/tagging.ts +++ b/apps/workers/workers/inference/tagging.ts @@ -7,6 +7,7 @@ import type { InferenceClient, InferenceResponse, } from "@karakeep/shared/inference"; +import type { ZTagStyle } from "@karakeep/shared/types/users"; import { db } from "@karakeep/db"; import { bookmarks, @@ -79,6 +80,8 @@ function tagNormalizer() { } async function buildPrompt( bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, + tagStyle: ZTagStyle, + inferredTagLang: string, ): Promise<string | null> { const prompts = await fetchCustomPrompts(bookmark.userId, "text"); if (bookmark.link) { @@ -96,22 +99,24 @@ async function buildPrompt( return null; } return await buildTextPrompt( - serverConfig.inference.inferredTagLang, + inferredTagLang, prompts, `URL: ${bookmark.link.url} Title: ${bookmark.link.title ?? ""} Description: ${bookmark.link.description ?? ""} Content: ${content ?? ""}`, serverConfig.inference.contextLength, + tagStyle, ); } if (bookmark.text) { return await buildTextPrompt( - serverConfig.inference.inferredTagLang, + inferredTagLang, prompts, bookmark.text.text ?? "", serverConfig.inference.contextLength, + tagStyle, ); } @@ -123,6 +128,8 @@ async function inferTagsFromImage( bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, inferenceClient: InferenceClient, abortSignal: AbortSignal, + tagStyle: ZTagStyle, + inferredTagLang: string, ): Promise<InferenceResponse | null> { const { asset, metadata } = await readAsset({ userId: bookmark.userId, @@ -144,8 +151,9 @@ async function inferTagsFromImage( const base64 = asset.toString("base64"); return inferenceClient.inferFromImage( buildImagePrompt( - serverConfig.inference.inferredTagLang, + inferredTagLang, await fetchCustomPrompts(bookmark.userId, "images"), + tagStyle, ), metadata.contentType, base64, @@ -215,12 +223,15 @@ async function inferTagsFromPDF( bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, inferenceClient: InferenceClient, abortSignal: AbortSignal, + tagStyle: ZTagStyle, + inferredTagLang: string, ) { const prompt = await buildTextPrompt( - serverConfig.inference.inferredTagLang, + inferredTagLang, await fetchCustomPrompts(bookmark.userId, "text"), `Content: ${bookmark.asset.content}`, serverConfig.inference.contextLength, + tagStyle, ); return inferenceClient.inferFromText(prompt, { schema: openAIResponseSchema, @@ -232,8 +243,10 @@ async function inferTagsFromText( bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, inferenceClient: InferenceClient, abortSignal: AbortSignal, + tagStyle: ZTagStyle, + inferredTagLang: string, ) { - const prompt = await buildPrompt(bookmark); + const prompt = await buildPrompt(bookmark, tagStyle, inferredTagLang); if (!prompt) { return null; } @@ -248,10 +261,18 @@ async function inferTags( bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>, inferenceClient: InferenceClient, abortSignal: AbortSignal, + tagStyle: ZTagStyle, + inferredTagLang: string, ) { let response: InferenceResponse | null; if (bookmark.link || bookmark.text) { - response = await inferTagsFromText(bookmark, inferenceClient, abortSignal); + response = await inferTagsFromText( + bookmark, + inferenceClient, + abortSignal, + tagStyle, + inferredTagLang, + ); } else if (bookmark.asset) { switch (bookmark.asset.assetType) { case "image": @@ -260,6 +281,8 @@ async function inferTags( bookmark, inferenceClient, abortSignal, + tagStyle, + inferredTagLang, ); break; case "pdf": @@ -268,6 +291,8 @@ async function inferTags( bookmark, inferenceClient, abortSignal, + tagStyle, + inferredTagLang, ); break; default: @@ -443,6 +468,8 @@ export async function runTagging( where: eq(users.id, bookmark.userId), columns: { autoTaggingEnabled: true, + tagStyle: true, + inferredTagLang: true, }, }); @@ -462,6 +489,8 @@ export async function runTagging( bookmark, inferenceClient, job.abortSignal, + userSettings?.tagStyle ?? "as-generated", + userSettings?.inferredTagLang ?? serverConfig.inference.inferredTagLang, ); if (tags === null) { |
