aboutsummaryrefslogtreecommitdiffstats
path: root/apps/workers
diff options
context:
space:
mode:
Diffstat (limited to 'apps/workers')
-rw-r--r--apps/workers/workers/inference/summarize.ts3
-rw-r--r--apps/workers/workers/inference/tagging.ts41
2 files changed, 37 insertions, 7 deletions
diff --git a/apps/workers/workers/inference/summarize.ts b/apps/workers/workers/inference/summarize.ts
index 460c3328..560bb5a2 100644
--- a/apps/workers/workers/inference/summarize.ts
+++ b/apps/workers/workers/inference/summarize.ts
@@ -61,6 +61,7 @@ export async function runSummarization(
where: eq(users.id, bookmarkData.userId),
columns: {
autoSummarizationEnabled: true,
+ inferredTagLang: true,
},
});
@@ -121,7 +122,7 @@ URL: ${link.url ?? ""}
});
const summaryPrompt = await buildSummaryPrompt(
- serverConfig.inference.inferredTagLang,
+ userSettings?.inferredTagLang ?? serverConfig.inference.inferredTagLang,
prompts.map((p) => p.text),
textToSummarize,
serverConfig.inference.contextLength,
diff --git a/apps/workers/workers/inference/tagging.ts b/apps/workers/workers/inference/tagging.ts
index 6d20b953..ace426a1 100644
--- a/apps/workers/workers/inference/tagging.ts
+++ b/apps/workers/workers/inference/tagging.ts
@@ -7,6 +7,7 @@ import type {
InferenceClient,
InferenceResponse,
} from "@karakeep/shared/inference";
+import type { ZTagStyle } from "@karakeep/shared/types/users";
import { db } from "@karakeep/db";
import {
bookmarks,
@@ -79,6 +80,8 @@ function tagNormalizer() {
}
async function buildPrompt(
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
+ tagStyle: ZTagStyle,
+ inferredTagLang: string,
): Promise<string | null> {
const prompts = await fetchCustomPrompts(bookmark.userId, "text");
if (bookmark.link) {
@@ -96,22 +99,24 @@ async function buildPrompt(
return null;
}
return await buildTextPrompt(
- serverConfig.inference.inferredTagLang,
+ inferredTagLang,
prompts,
`URL: ${bookmark.link.url}
Title: ${bookmark.link.title ?? ""}
Description: ${bookmark.link.description ?? ""}
Content: ${content ?? ""}`,
serverConfig.inference.contextLength,
+ tagStyle,
);
}
if (bookmark.text) {
return await buildTextPrompt(
- serverConfig.inference.inferredTagLang,
+ inferredTagLang,
prompts,
bookmark.text.text ?? "",
serverConfig.inference.contextLength,
+ tagStyle,
);
}
@@ -123,6 +128,8 @@ async function inferTagsFromImage(
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
inferenceClient: InferenceClient,
abortSignal: AbortSignal,
+ tagStyle: ZTagStyle,
+ inferredTagLang: string,
): Promise<InferenceResponse | null> {
const { asset, metadata } = await readAsset({
userId: bookmark.userId,
@@ -144,8 +151,9 @@ async function inferTagsFromImage(
const base64 = asset.toString("base64");
return inferenceClient.inferFromImage(
buildImagePrompt(
- serverConfig.inference.inferredTagLang,
+ inferredTagLang,
await fetchCustomPrompts(bookmark.userId, "images"),
+ tagStyle,
),
metadata.contentType,
base64,
@@ -215,12 +223,15 @@ async function inferTagsFromPDF(
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
inferenceClient: InferenceClient,
abortSignal: AbortSignal,
+ tagStyle: ZTagStyle,
+ inferredTagLang: string,
) {
const prompt = await buildTextPrompt(
- serverConfig.inference.inferredTagLang,
+ inferredTagLang,
await fetchCustomPrompts(bookmark.userId, "text"),
`Content: ${bookmark.asset.content}`,
serverConfig.inference.contextLength,
+ tagStyle,
);
return inferenceClient.inferFromText(prompt, {
schema: openAIResponseSchema,
@@ -232,8 +243,10 @@ async function inferTagsFromText(
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
inferenceClient: InferenceClient,
abortSignal: AbortSignal,
+ tagStyle: ZTagStyle,
+ inferredTagLang: string,
) {
- const prompt = await buildPrompt(bookmark);
+ const prompt = await buildPrompt(bookmark, tagStyle, inferredTagLang);
if (!prompt) {
return null;
}
@@ -248,10 +261,18 @@ async function inferTags(
bookmark: NonNullable<Awaited<ReturnType<typeof fetchBookmark>>>,
inferenceClient: InferenceClient,
abortSignal: AbortSignal,
+ tagStyle: ZTagStyle,
+ inferredTagLang: string,
) {
let response: InferenceResponse | null;
if (bookmark.link || bookmark.text) {
- response = await inferTagsFromText(bookmark, inferenceClient, abortSignal);
+ response = await inferTagsFromText(
+ bookmark,
+ inferenceClient,
+ abortSignal,
+ tagStyle,
+ inferredTagLang,
+ );
} else if (bookmark.asset) {
switch (bookmark.asset.assetType) {
case "image":
@@ -260,6 +281,8 @@ async function inferTags(
bookmark,
inferenceClient,
abortSignal,
+ tagStyle,
+ inferredTagLang,
);
break;
case "pdf":
@@ -268,6 +291,8 @@ async function inferTags(
bookmark,
inferenceClient,
abortSignal,
+ tagStyle,
+ inferredTagLang,
);
break;
default:
@@ -443,6 +468,8 @@ export async function runTagging(
where: eq(users.id, bookmark.userId),
columns: {
autoTaggingEnabled: true,
+ tagStyle: true,
+ inferredTagLang: true,
},
});
@@ -462,6 +489,8 @@ export async function runTagging(
bookmark,
inferenceClient,
job.abortSignal,
+ userSettings?.tagStyle ?? "as-generated",
+ userSettings?.inferredTagLang ?? serverConfig.inference.inferredTagLang,
);
if (tags === null) {